tailssh_test.go 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux || darwin
  4. package tailssh
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "io/ioutil"
  16. "net"
  17. "net/http"
  18. "net/http/httptest"
  19. "net/netip"
  20. "os"
  21. "os/exec"
  22. "os/user"
  23. "reflect"
  24. "runtime"
  25. "strings"
  26. "sync"
  27. "sync/atomic"
  28. "testing"
  29. "time"
  30. gossh "github.com/tailscale/golang-x-crypto/ssh"
  31. "tailscale.com/ipn/ipnlocal"
  32. "tailscale.com/ipn/store/mem"
  33. "tailscale.com/net/memnet"
  34. "tailscale.com/net/tsdial"
  35. "tailscale.com/tailcfg"
  36. "tailscale.com/tempfork/gliderlabs/ssh"
  37. "tailscale.com/tstest"
  38. "tailscale.com/types/logger"
  39. "tailscale.com/types/logid"
  40. "tailscale.com/types/netmap"
  41. "tailscale.com/util/cibuild"
  42. "tailscale.com/util/lineread"
  43. "tailscale.com/util/must"
  44. "tailscale.com/version/distro"
  45. "tailscale.com/wgengine"
  46. )
  // TestMatchRule exercises conn.matchRule against a single rule at a time,
// checking the sentinel error returned for each rejection reason and the
// local user selected when the rule matches.
func TestMatchRule(t *testing.T) {
	someAction := new(tailcfg.SSHAction)
	tests := []struct {
		name     string
		rule     *tailcfg.SSHRule
		ci       *sshConnInfo
		wantErr  error
		wantUser string
	}{
		{
			name: "invalid-conn",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
				SSHUsers: map[string]string{
					"*": "ubuntu",
				},
			},
			wantErr: errInvalidConn,
		},
		{
			name:    "nil-rule",
			ci:      &sshConnInfo{},
			rule:    nil,
			wantErr: errNilRule,
		},
		{
			name:    "nil-action",
			ci:      &sshConnInfo{},
			rule:    &tailcfg.SSHRule{},
			wantErr: errNilAction,
		},
		{
			name: "expired",
			rule: &tailcfg.SSHRule{
				Action:      someAction,
				RuleExpires: timePtr(time.Unix(100, 0)),
			},
			ci:      &sshConnInfo{},
			wantErr: errRuleExpired,
		},
		{
			name: "no-principal",
			rule: &tailcfg.SSHRule{
				Action: someAction,
				SSHUsers: map[string]string{
					"*": "ubuntu",
				}},
			ci:      &sshConnInfo{},
			wantErr: errPrincipalMatch,
		},
		{
			name: "no-user-match",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
			},
			ci:      &sshConnInfo{sshUser: "alice"},
			wantErr: errUserMatch,
		},
		{
			name: "ok-wildcard",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
				SSHUsers: map[string]string{
					"*": "ubuntu",
				},
			},
			ci:       &sshConnInfo{sshUser: "alice"},
			wantUser: "ubuntu",
		},
		{
			name: "ok-wildcard-and-nil-principal",
			rule: &tailcfg.SSHRule{
				Action: someAction,
				Principals: []*tailcfg.SSHPrincipal{
					nil, // don't crash on this
					{Any: true},
				},
				SSHUsers: map[string]string{
					"*": "ubuntu",
				},
			},
			ci:       &sshConnInfo{sshUser: "alice"},
			wantUser: "ubuntu",
		},
		{
			name: "ok-exact",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
				SSHUsers: map[string]string{
					"*":     "ubuntu",
					"alice": "thealice",
				},
			},
			ci:       &sshConnInfo{sshUser: "alice"},
			wantUser: "thealice",
		},
		{
			name: "no-users-for-reject",
			rule: &tailcfg.SSHRule{
				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
				Action:     &tailcfg.SSHAction{Reject: true},
			},
			ci: &sshConnInfo{sshUser: "alice"},
		},
		{
			name: "match-principal-node-ip",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
				SSHUsers:   map[string]string{"*": "ubuntu"},
			},
			ci:       &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
			wantUser: "ubuntu",
		},
		{
			name: "match-principal-node-id",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
				SSHUsers:   map[string]string{"*": "ubuntu"},
			},
			ci:       &sshConnInfo{node: &tailcfg.Node{StableID: "some-node-ID"}},
			wantUser: "ubuntu",
		},
		{
			// The principal's UserLogin and the connection's LoginName must
			// be the same string for the rule to match.
			name: "match-principal-userlogin",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}},
				SSHUsers:   map[string]string{"*": "ubuntu"},
			},
			ci:       &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}},
			wantUser: "ubuntu",
		},
		{
			name: "ssh-user-equal",
			rule: &tailcfg.SSHRule{
				Action:     someAction,
				Principals: []*tailcfg.SSHPrincipal{{Any: true}},
				SSHUsers: map[string]string{
					"*": "=",
				},
			},
			ci:       &sshConnInfo{sshUser: "alice"},
			wantUser: "alice",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &conn{
				info: tt.ci,
				srv:  &server{logf: t.Logf},
			}
			got, gotUser, err := c.matchRule(tt.rule, nil)
			// matchRule reports rejections with sentinel errors. errors.Is
			// keeps this check valid if they are ever wrapped with context;
			// errors.Is(nil, nil) is true for the success cases.
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("err = %v; want %v", err, tt.wantErr)
			}
			if gotUser != tt.wantUser {
				t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
			}
			if err == nil && got == nil {
				t.Errorf("expected non-nil action on success")
			}
		})
	}
}
  // timePtr returns a pointer to a fresh copy of t, which is handy for filling
// in optional *time.Time fields such as tailcfg.SSHRule.RuleExpires.
func timePtr(t time.Time) *time.Time {
	p := new(time.Time)
	*p = t
	return p
}
  218. // localState implements ipnLocalBackend for testing.
  219. type localState struct {
  220. sshEnabled bool
  221. matchingRule *tailcfg.SSHRule
  222. // serverActions is a map of the action name to the action.
  223. // It is served for paths like https://unused/ssh-action/<action-name>.
  224. // The action name is the last part of the action URL.
  225. serverActions map[string]*tailcfg.SSHAction
  226. }
  227. var (
  228. currentUser = os.Getenv("USER") // Use the current user for the test.
  229. testSigner gossh.Signer
  230. testSignerOnce sync.Once
  231. )
  232. func (ts *localState) Dialer() *tsdial.Dialer {
  233. return nil
  234. }
  // GetSSH_HostKeys implements ipnLocalBackend. Every call, on any localState,
// returns the same single ed25519 host key. The key is generated on first use
// and stored in the package-level testSigner. Key generation failures panic,
// because no test can proceed without a host key.
func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
	testSignerOnce.Do(func() {
		_, key, genErr := ed25519.GenerateKey(rand.Reader)
		if genErr != nil {
			panic(genErr)
		}
		signer, signErr := gossh.NewSignerFromSigner(key)
		if signErr != nil {
			panic(signErr)
		}
		testSigner = signer
	})
	hostKeys := []gossh.Signer{testSigner}
	return hostKeys, nil
}
  249. func (ts *localState) ShouldRunSSH() bool {
  250. return ts.sshEnabled
  251. }
  252. func (ts *localState) NetMap() *netmap.NetworkMap {
  253. var policy *tailcfg.SSHPolicy
  254. if ts.matchingRule != nil {
  255. policy = &tailcfg.SSHPolicy{
  256. Rules: []*tailcfg.SSHRule{
  257. ts.matchingRule,
  258. },
  259. }
  260. }
  261. return &netmap.NetworkMap{
  262. SelfNode: &tailcfg.Node{
  263. ID: 1,
  264. },
  265. SSHPolicy: policy,
  266. }
  267. }
  // WhoIs implements ipnLocalBackend. It ignores ipp: every address resolves
// to the same peer (node ID 2, stable ID "peer-id") owned by login "peer".
func (ts *localState) WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) {
	n = &tailcfg.Node{ID: 2, StableID: "peer-id"}
	u = tailcfg.UserProfile{LoginName: "peer"}
	ok = true
	return n, u, ok
}
  // DoNoiseRequest implements ipnLocalBackend. Instead of contacting a control
// server, it answers from ts.serverActions. A request for /ssh-action/<name>
// gets a 200 response whose body is the JSON encoding of serverActions[name].
// Any other path, or an unknown name, gets a 404. The only error returned is
// a JSON encoding failure.
func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
	rec := httptest.NewRecorder()
	k, ok := strings.CutPrefix(req.URL.Path, "/ssh-action/")
	if !ok {
		// Not an action URL. Stop here; continuing would look the whole
		// path up in serverActions as if it were an action name.
		rec.WriteHeader(http.StatusNotFound)
		return rec.Result(), nil
	}
	a, ok := ts.serverActions[k]
	if !ok {
		rec.WriteHeader(http.StatusNotFound)
		return rec.Result(), nil
	}
	rec.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(rec).Encode(a); err != nil {
		return nil, err
	}
	return rec.Result(), nil
}
  293. func (ts *localState) TailscaleVarRoot() string {
  294. return ""
  295. }
  296. func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
  297. return &tailcfg.SSHRule{
  298. SSHUsers: map[string]string{
  299. "*": currentUser,
  300. },
  301. Action: action,
  302. Principals: []*tailcfg.SSHPrincipal{
  303. {
  304. Any: true,
  305. },
  306. },
  307. }
  308. }
  // TestSSHRecordingCancelsSessionsOnUploadFailure checks that a recorded SSH
// session is cut short when the recording upload fails. It covers two cases:
// the recorder refuses the upload outright (403), and the recorder fails it
// after it has started (500). In each case it checks the output the client
// sees.
func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
	}

	// uploadHandler is swapped per subtest via tstest.Replace. The recording
	// server just forwards every request to it.
	var uploadHandler http.HandlerFunc
	recorder := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		uploadHandler(w, r)
	}))
	defer recorder.Close()

	recorderAddr := netip.MustParseAddrPort(recorder.Listener.Addr().String())
	s := &server{
		logf:  t.Logf,
		httpc: recorder.Client(),
		lb: &localState{
			sshEnabled: true,
			matchingRule: newSSHRule(&tailcfg.SSHAction{
				Accept:    true,
				Recorders: []netip.AddrPort{recorderAddr},
			}),
		},
	}
	defer s.Shutdown()

	const sshUser = "alice"
	clientConfig := &gossh.ClientConfig{
		User:            sshUser,
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	}

	cases := []struct {
		name             string
		handler          func(w http.ResponseWriter, r *http.Request)
		sshCommand       string
		wantClientOutput string
	}{
		{
			name: "upload-denied",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusForbidden)
			},
			sshCommand:       "echo hello",
			wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
		},
		{
			name: "upload-fails-after-starting",
			handler: func(w http.ResponseWriter, r *http.Request) {
				// Consume a byte so the upload is under way before failing it.
				r.Body.Read(make([]byte, 1))
				time.Sleep(100 * time.Millisecond)
				w.WriteHeader(http.StatusInternalServerError)
			},
			sshCommand:       "echo hello && sleep 1 && echo world",
			wantClientOutput: "hello\n\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
		},
	}

	src := must.Get(netip.ParseAddrPort("100.100.100.101:2231"))
	dst := must.Get(netip.ParseAddrPort("100.100.100.102:22"))
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			tstest.Replace(t, &uploadHandler, tc.handler)
			clientSide, serverSide := memnet.NewTCPConn(src, dst, 1024)

			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				defer wg.Done()
				cc, chans, reqs, err := gossh.NewClientConn(clientSide, clientSide.RemoteAddr().String(), clientConfig)
				if err != nil {
					t.Errorf("client: %v", err)
					return
				}
				client := gossh.NewClient(cc, chans, reqs)
				defer client.Close()
				session, err := client.NewSession()
				if err != nil {
					t.Errorf("client: %v", err)
					return
				}
				defer session.Close()
				t.Logf("client established session")

				out, err := session.CombinedOutput(tc.sshCommand)
				if err == nil {
					t.Errorf("client did not get kicked out: %q", out)
				} else {
					t.Logf("client got: %q: %v", out, err)
				}
				if string(out) != tc.wantClientOutput {
					t.Errorf("client got %q, want %q", out, tc.wantClientOutput)
				}
			}()

			if err := s.HandleSSHConn(serverSide); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			wg.Wait()
		})
	}
}
  // TestSSHRecordingNonInteractive tests that the SSH server records the SSH
// session when the client is not interactive (i.e. has no PTY).
//
// It runs an SSH server over an in-memory connection, together with a local
// recording server that captures the uploaded recording. The test then checks
// the recording's CastHeader (SSH user and command). It does not validate the
// rest of the recording.
func TestSSHRecordingNonInteractive(t *testing.T) {
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
	}
	// recording is written by the recording server's handler, which cancels
	// ctx when it returns. Receiving from ctx.Done() therefore orders that
	// write before the reads below. If the deadline fires instead, the
	// handler may still be writing, so recording must not be read then.
	var recording []byte
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel() // release the context's timer on every return path
	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer cancel()
		var err error
		recording, err = ioutil.ReadAll(r.Body)
		if err != nil {
			t.Error(err)
			return
		}
	}))
	defer recordingServer.Close()
	s := &server{
		logf:  logger.Discard,
		httpc: recordingServer.Client(),
		lb: &localState{
			sshEnabled: true,
			matchingRule: newSSHRule(
				&tailcfg.SSHAction{
					Accept: true,
					Recorders: []netip.AddrPort{
						must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
					},
				},
			),
		},
	}
	defer s.Shutdown()
	src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
	sc, dc := memnet.NewTCPConn(src, dst, 1024)
	const sshUser = "alice"
	cfg := &gossh.ClientConfig{
		User:            sshUser,
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	}
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
		if err != nil {
			t.Errorf("client: %v", err)
			return
		}
		client := gossh.NewClient(c, chans, reqs)
		defer client.Close()
		session, err := client.NewSession()
		if err != nil {
			t.Errorf("client: %v", err)
			return
		}
		defer session.Close()
		t.Logf("client established session")
		_, err = session.CombinedOutput("echo Ran echo!")
		if err != nil {
			t.Errorf("client: %v", err)
		}
	}()
	if err := s.HandleSSHConn(dc); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	wg.Wait()
	<-ctx.Done() // wait for recording to finish
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		// The handler never completed; report that directly instead of
		// racing its write and failing on a confusing JSON decode error.
		t.Fatal("timed out waiting for the session recording to be uploaded")
	}
	var ch CastHeader
	if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
		t.Fatal(err)
	}
	if ch.SSHUser != sshUser {
		t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
	}
	if ch.Command != "echo Ran echo!" {
		t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
	}
}
  // TestSSHAuthFlow drives a Go SSH client through the server's authentication
// flow under several SSH policies: none, accept, reject, HoldAndDelegate
// chains fetched through DoNoiseRequest, and a "+password" user-name suffix
// that is expected to force password auth. For each policy it checks the
// banners shown along the way, whether password auth was used, and whether
// authentication succeeds.
func TestSSHAuthFlow(t *testing.T) {
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
	}
	acceptRule := newSSHRule(&tailcfg.SSHAction{
		Accept:  true,
		Message: "Welcome to Tailscale SSH!",
	})
	rejectRule := newSSHRule(&tailcfg.SSHAction{
		Reject:  true,
		Message: "Go Away!",
	})
	tests := []struct {
		name         string
		sshUser      string // defaults to alice
		state        *localState
		wantBanners  []string
		usesPassword bool
		authErr      bool
	}{
		{
			name: "no-policy",
			state: &localState{
				sshEnabled: true,
			},
			authErr: true,
		},
		{
			name: "accept",
			state: &localState{
				sshEnabled:   true,
				matchingRule: acceptRule,
			},
			wantBanners: []string{"Welcome to Tailscale SSH!"},
		},
		{
			name: "reject",
			state: &localState{
				sshEnabled:   true,
				matchingRule: rejectRule,
			},
			wantBanners: []string{"Go Away!"},
			authErr:     true,
		},
		{
			name: "simple-check",
			state: &localState{
				sshEnabled: true,
				matchingRule: newSSHRule(&tailcfg.SSHAction{
					HoldAndDelegate: "https://unused/ssh-action/accept",
				}),
				serverActions: map[string]*tailcfg.SSHAction{
					"accept": acceptRule.Action,
				},
			},
			wantBanners: []string{"Welcome to Tailscale SSH!"},
		},
		{
			name: "multi-check",
			state: &localState{
				sshEnabled: true,
				matchingRule: newSSHRule(&tailcfg.SSHAction{
					Message:         "First",
					HoldAndDelegate: "https://unused/ssh-action/check1",
				}),
				serverActions: map[string]*tailcfg.SSHAction{
					"check1": {
						Message:         "url-here",
						HoldAndDelegate: "https://unused/ssh-action/check2",
					},
					"check2": acceptRule.Action,
				},
			},
			wantBanners: []string{"First", "url-here", "Welcome to Tailscale SSH!"},
		},
		{
			name: "check-reject",
			state: &localState{
				sshEnabled: true,
				matchingRule: newSSHRule(&tailcfg.SSHAction{
					Message:         "First",
					HoldAndDelegate: "https://unused/ssh-action/reject",
				}),
				serverActions: map[string]*tailcfg.SSHAction{
					"reject": rejectRule.Action,
				},
			},
			wantBanners: []string{"First", "Go Away!"},
			authErr:     true,
		},
		{
			name:    "force-password-auth",
			sshUser: "alice+password",
			state: &localState{
				sshEnabled:   true,
				matchingRule: acceptRule,
			},
			usesPassword: true,
			wantBanners:  []string{"Welcome to Tailscale SSH!"},
		},
	}
	s := &server{
		logf: logger.Discard,
	}
	defer s.Shutdown()
	src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			sc, dc := memnet.NewTCPConn(src, dst, 1024)
			s.lb = tc.state
			sshUser := "alice"
			if tc.sshUser != "" {
				sshUser = tc.sshUser
			}
			var passwordUsed atomic.Bool
			cfg := &gossh.ClientConfig{
				User:            sshUser,
				HostKeyCallback: gossh.InsecureIgnoreHostKey(),
				Auth: []gossh.AuthMethod{
					gossh.PasswordCallback(func() (secret string, err error) {
						if !tc.usesPassword {
							t.Error("unexpected use of PasswordCallback")
							return "", errors.New("unexpected use of PasswordCallback")
						}
						passwordUsed.Store(true)
						return "any-pass", nil
					}),
				},
				// BannerCallback consumes tc.wantBanners in order. It runs on
				// the client goroutine; wg.Wait below orders its writes
				// before the final check.
				BannerCallback: func(message string) error {
					if len(tc.wantBanners) == 0 {
						t.Errorf("unexpected banner: %q", message)
					} else if message != tc.wantBanners[0] {
						t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
					} else {
						t.Logf("banner = %q", message)
						tc.wantBanners = tc.wantBanners[1:]
					}
					return nil
				},
			}
			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				defer wg.Done()
				c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
				if err != nil {
					if !tc.authErr {
						t.Errorf("client: %v", err)
					}
					return
				} else if tc.authErr {
					c.Close()
					t.Errorf("client: expected error, got nil")
					return
				}
				client := gossh.NewClient(c, chans, reqs)
				defer client.Close()
				session, err := client.NewSession()
				if err != nil {
					t.Errorf("client: %v", err)
					return
				}
				defer session.Close()
				_, err = session.CombinedOutput("echo Ran echo!")
				if err != nil {
					t.Errorf("client: %v", err)
				}
			}()
			if err := s.HandleSSHConn(dc); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			wg.Wait()
			// PasswordCallback already fails the test if it runs when it
			// shouldn't. Also require that it ran when password auth was
			// expected; otherwise the force-password-auth case checks nothing.
			if tc.usesPassword && !passwordUsed.Load() {
				t.Error("client never used PasswordCallback; want password auth")
			}
			if len(tc.wantBanners) > 0 {
				t.Errorf("missing banners: %v", tc.wantBanners)
			}
		})
	}
}
  // TestSSH serves a conn over a TCP listener on 127.0.0.1, with Tailscale auth
// disabled and the session running as the current user. It drives that conn
// with the system ssh client and checks the remote environment, stdout/stderr
// handling, and stdin passthrough.
func TestSSH(t *testing.T) {
	// The subtests shell out to the ssh client. Skip, rather than fail with
	// an exec error, on machines that don't have one installed.
	if _, err := exec.LookPath("ssh"); err != nil {
		t.Skipf("skipping; no ssh client available: %v", err)
	}
	var logf logger.Logf = t.Logf
	eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
	if err != nil {
		t.Fatal(err)
	}
	lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{},
		new(mem.Store),
		new(tsdial.Dialer),
		eng, 0)
	if err != nil {
		t.Fatal(err)
	}
	defer lb.Shutdown()
	dir := t.TempDir()
	lb.SetVarRoot(dir)
	srv := &server{
		lb:   lb,
		logf: logf,
	}
	sc, err := srv.newConn()
	if err != nil {
		t.Fatal(err)
	}
	// Remove the auth checks for the test
	sc.insecureSkipTailscaleAuth = true
	u, err := user.Current()
	if err != nil {
		t.Fatal(err)
	}
	sc.localUser = u
	sc.info = &sshConnInfo{
		sshUser: "test",
		src:     netip.MustParseAddrPort("1.2.3.4:32342"),
		dst:     netip.MustParseAddrPort("1.2.3.5:22"),
		node:    &tailcfg.Node{},
		uprof:   tailcfg.UserProfile{},
	}
	sc.action0 = &tailcfg.SSHAction{Accept: true}
	sc.finalAction = sc.action0
	sc.Handler = func(s ssh.Session) {
		sc.newSSHSession(s).run()
	}
	ln, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()
	port := ln.Addr().(*net.TCPAddr).Port
	go func() {
		for {
			c, err := ln.Accept()
			if err != nil {
				if !errors.Is(err, net.ErrClosed) {
					t.Errorf("Accept: %v", err)
				}
				return
			}
			go sc.HandleConn(c)
		}
	}()
	execSSH := func(args ...string) *exec.Cmd {
		cmd := exec.Command("ssh",
			"-F",
			"none",
			"-v",
			"-p", fmt.Sprint(port),
			"-o", "StrictHostKeyChecking=no",
			// The host must be 127.0.0.1 to reach ln. The login name is
			// presumably irrelevant, since Tailscale auth is skipped and
			// sc.localUser is preset.
			"user@127.0.0.1")
		cmd.Args = append(cmd.Args, args...)
		return cmd
	}
	t.Run("env", func(t *testing.T) {
		if cibuild.On() {
			t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
		}
		cmd := execSSH("LANG=foo env")
		cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
		got, err := cmd.CombinedOutput()
		if err != nil {
			t.Fatal(err, string(got))
		}
		m := parseEnv(got)
		if got := m["USER"]; got == "" || got != u.Username {
			t.Errorf("USER = %q; want %q", got, u.Username)
		}
		if got := m["HOME"]; got == "" || got != u.HomeDir {
			t.Errorf("HOME = %q; want %q", got, u.HomeDir)
		}
		if got := m["PWD"]; got == "" || got != u.HomeDir {
			t.Errorf("PWD = %q; want %q", got, u.HomeDir)
		}
		if got := m["SHELL"]; got == "" {
			t.Errorf("no SHELL")
		}
		if got, want := m["LANG"], "foo"; got != want {
			t.Errorf("LANG = %q; want %q", got, want)
		}
		if got := m["LOCAL_ENV"]; got != "" {
			t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
		}
		t.Logf("got: %+v", m)
	})
	t.Run("stdout_stderr", func(t *testing.T) {
		cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
		var outBuf, errBuf bytes.Buffer
		cmd.Stdout = &outBuf
		cmd.Stderr = &errBuf
		if err := cmd.Run(); err != nil {
			t.Fatal(err)
		}
		t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
		// TODO: figure out why these aren't right. should be
		// "foo\n" and "bar\n", not "\n" and "bar\n".
	})
	t.Run("stdin", func(t *testing.T) {
		if cibuild.On() {
			t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
		}
		cmd := execSSH("cat")
		var outBuf bytes.Buffer
		cmd.Stdout = &outBuf
		const str = "foo\nbar\n"
		cmd.Stdin = strings.NewReader(str)
		if err := cmd.Run(); err != nil {
			t.Fatal(err)
		}
		if got := outBuf.String(); got != str {
			t.Errorf("got %q; want %q", got, str)
		}
	})
}
  // parseEnv parses the output of env(1) into a map. Each line is split at its
// first '='; lines without an '=' are skipped, and a repeated key keeps its
// last value.
func parseEnv(out []byte) map[string]string {
	env := make(map[string]string)
	// The reader is an in-memory buffer and the callback never fails, so
	// the returned error carries no information and is dropped.
	_ = lineread.Reader(bytes.NewReader(out), func(line []byte) error {
		if key, val, found := bytes.Cut(line, []byte{'='}); found {
			env[string(key)] = string(val)
		}
		return nil
	})
	return env
}
  // TestPublicKeyFetching checks the caching in server.fetchPublicKeysURL. Two
// back-to-back fetches of the same URL must reach the key server only once.
// After the fake clock advances 5 minutes, the next fetch must revalidate with
// a matching If-None-Match, be answered 304 Not Modified, and still return
// the cached keys.
func TestPublicKeyFetching(t *testing.T) {
	var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss atomic.Int32
	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqsTotal.Add(1)
		// A weak ETag derived from the path, so it is stable across requests.
		etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
		w.Header().Set("Etag", etag)
		if v := r.Header.Get("If-None-Match"); v != "" {
			if v == etag {
				reqsIfNoneMatchHit.Add(1)
				w.WriteHeader(http.StatusNotModified)
				return
			}
			reqsIfNoneMatchMiss.Add(1)
		}
		io.WriteString(w, "foo\nbar\n"+r.URL.Path+"\n")
	}))
	ts.StartTLS()
	defer ts.Close()
	keys := ts.URL
	clock := &tstest.Clock{}
	srv := &server{
		pubKeyHTTPClient: ts.Client(),
		timeNow:          clock.Now,
	}
	for i := 0; i < 2; i++ {
		got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
		if err != nil {
			t.Fatal(err)
		}
		if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
			t.Errorf("got %q; want %q", got, want)
		}
	}
	if got, want := reqsTotal.Load(), int32(1); got != want {
		t.Errorf("got %d requests; want %d", got, want)
	}
	if got, want := reqsIfNoneMatchHit.Load(), int32(0); got != want {
		t.Errorf("got %d etag hits; want %d", got, want)
	}
	clock.Advance(5 * time.Minute)
	got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
	if err != nil {
		t.Fatal(err)
	}
	if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
		t.Errorf("got %q; want %q", got, want)
	}
	if got, want := reqsTotal.Load(), int32(2); got != want {
		t.Errorf("got %d requests; want %d", got, want)
	}
	if got, want := reqsIfNoneMatchHit.Load(), int32(1); got != want {
		t.Errorf("got %d etag hits; want %d", got, want)
	}
	if got, want := reqsIfNoneMatchMiss.Load(), int32(0); got != want {
		t.Errorf("got %d etag misses; want %d", got, want)
	}
}
  // TestExpandPublicKeyURL checks the $LOGINNAME_LOCALPART and $LOGINNAME_EMAIL
// substitutions made by conn.expandPublicKeyURL, a URL with no placeholders,
// and expansion when the login name is empty.
func TestExpandPublicKeyURL(t *testing.T) {
	c := &conn{
		info: &sshConnInfo{
			uprof: tailcfg.UserProfile{
				// Must contain an '@': the local part ("bar") and the full
				// address are both derived from it below.
				LoginName: "bar@baz.tld",
			},
		},
	}
	tests := []struct {
		name, in, want string
	}{
		{"basic", "foo", "foo"},
		{"localpart", "https://example.com/$LOGINNAME_LOCALPART.keys", "https://example.com/bar.keys"},
		{"email", "https://example.com/keys?email=$LOGINNAME_EMAIL", "https://example.com/keys?email=bar@baz.tld"},
	}
	for _, tt := range tests {
		if got := c.expandPublicKeyURL(tt.in); got != tt.want {
			t.Errorf("%s: got %q; want %q", tt.name, got, tt.want)
		}
	}
	c.info = new(sshConnInfo)
	if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
		t.Errorf("on empty: got %q; want %q", got, want)
	}
}
  // TestAcceptEnvPair checks which client-supplied NAME=value environment
// pairs acceptEnvPair lets through: TERM and LC_* are accepted, while
// lower-case names, bare names without '=', and LD_PRELOAD are rejected.
func TestAcceptEnvPair(t *testing.T) {
	cases := []struct {
		pair   string
		accept bool
	}{
		{pair: "TERM=x", accept: true},
		{pair: "term=x", accept: false},
		{pair: "TERM", accept: false},
		{pair: "LC_FOO=x", accept: true},
		{pair: "LD_PRELOAD=naah", accept: false},
		{pair: "TERM=screen-256color", accept: true},
	}
	for _, c := range cases {
		got := acceptEnvPair(c.pair)
		if got == c.accept {
			continue
		}
		t.Errorf("for %q, got %v; want %v", c.pair, got, c.accept)
	}
}
  // TestPathFromPAMEnvLine checks that pathFromPAMEnvLine expands @{HOME} and
// @{PAM_USER} in a PAM `PATH DEFAULT="..."` line for a given user. It also
// checks that an empty line, or one with an unknown placeholder, yields "".
func TestPathFromPAMEnvLine(t *testing.T) {
	u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
	cases := []struct {
		line, want string
	}{
		{line: "", want: ""},
		{
			line: `PATH DEFAULT="/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"`,
			want: "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin",
		},
		{line: `PATH DEFAULT="@{SOMETHING_ELSE}:nope:@{HOME}"`, want: ""},
	}
	for i, c := range cases {
		if got := pathFromPAMEnvLine([]byte(c.line), u); got != c.want {
			t.Errorf("%d. got %q; want %q", i, got, c.want)
		}
	}
}
  // TestExpandDefaultPathTmpl checks that expandDefaultPathTmpl substitutes
// @{HOME} and @{PAM_USER} in a PATH template for a given user. It also checks
// that an empty template, or one with an unknown placeholder, yields "".
func TestExpandDefaultPathTmpl(t *testing.T) {
	u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
	cases := []struct {
		tmpl, want string
	}{
		{tmpl: "", want: ""},
		{
			tmpl: `/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin`,
			want: "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin",
		},
		{tmpl: `@{SOMETHING_ELSE}:nope:@{HOME}`, want: ""},
	}
	for i, c := range cases {
		if got := expandDefaultPathTmpl(c.tmpl, u); got != c.want {
			t.Errorf("%d. got %q; want %q", i, got, c.want)
		}
	}
}
  // TestPathFromPAMEnvLineOnNixOS runs only on NixOS. It checks that
// defaultPathForUserOnNixOS finds a default PATH for the current user, and on
// failure dumps /etc/pam/environment to aid debugging.
func TestPathFromPAMEnvLineOnNixOS(t *testing.T) {
	switch {
	case runtime.GOOS != "linux":
		t.Skip("skipping on non-linux")
	case distro.Get() != distro.NixOS:
		t.Skip("skipping on non-NixOS")
	}
	me, err := user.Current()
	if err != nil {
		t.Fatal(err)
	}
	path := defaultPathForUserOnNixOS(me)
	if path == "" {
		contents, readErr := os.ReadFile("/etc/pam/environment")
		t.Fatalf("no result. file was: err=%v, contents=%s", readErr, contents)
	}
	t.Logf("success; got=%q", path)
}