tailssh_integration_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build integrationtest
  4. // +build integrationtest
  5. package tailssh
  6. import (
  7. "bufio"
  8. "context"
  9. "crypto/ecdsa"
  10. "crypto/ed25519"
  11. "crypto/elliptic"
  12. "crypto/rand"
  13. "crypto/rsa"
  14. "crypto/x509"
  15. "encoding/pem"
  16. "fmt"
  17. "io"
  18. "log"
  19. "net"
  20. "net/http"
  21. "net/netip"
  22. "os"
  23. "os/exec"
  24. "runtime"
  25. "strings"
  26. "testing"
  27. "time"
  28. "github.com/bramvdbogaerde/go-scp"
  29. "github.com/google/go-cmp/cmp"
  30. "github.com/pkg/sftp"
  31. gossh "github.com/tailscale/golang-x-crypto/ssh"
  32. "golang.org/x/crypto/ssh"
  33. "tailscale.com/net/tsdial"
  34. "tailscale.com/tailcfg"
  35. "tailscale.com/types/key"
  36. "tailscale.com/types/netmap"
  37. "tailscale.com/util/set"
  38. )
  // This file contains integration tests of the SSH functionality. These tests
// exercise everything except for the authentication logic.
//
// The tests make the following assumptions about the environment:
//
// - OS is one of MacOS or Linux
// - Test is being run as root (e.g. go test -tags integrationtest -c . && sudo ./tailssh.test -test.run TestIntegration)
// - TAILSCALED_PATH environment variable points at tailscaled binary
// - User "testuser" exists
// - "testuser" is in groups "groupone" and "grouptwo"

// TestMain ensures /tmp/tailscalessh.log exists and follows it with `tail -F`
// for the duration of the run, echoing every line through the standard logger
// so that anything written to that file shows up in the test output.
//
// Setup failures are fatal, so a broken environment cannot masquerade as a
// passing run (returning early would skip m.Run and exit successfully).
func TestMain(m *testing.M) {
	const logPath = "/tmp/tailscalessh.log"

	// Create our log file (if needed) so that tail has something to follow.
	file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0666)
	if err != nil {
		log.Fatal(err)
	}
	file.Close()

	// Tail our log file.
	cmd := exec.Command("tail", "-F", logPath)
	r, err := cmd.StdoutPipe()
	if err != nil {
		log.Fatalf("unable to get stdout of tail: %s", err)
	}
	scanner := bufio.NewScanner(r)
	go func() {
		for scanner.Scan() {
			log.Println(scanner.Text())
		}
	}()
	if err := cmd.Start(); err != nil {
		log.Fatalf("unable to start tail: %s", err)
	}
	defer func() {
		// tail -f has a default sleep interval of 1 second, so it takes a
		// moment for it to finish reading our log file after we've terminated.
		// So, wait a bit to let it catch up.
		time.Sleep(2 * time.Second)
		// Stop tail so it does not outlive the test binary. Errors are
		// ignored: the process is killed deliberately, and Wait only reaps it
		// (closing the stdout pipe, which ends the scanner goroutine).
		_ = cmd.Process.Kill()
		_ = cmd.Wait()
	}()
	// Since Go 1.15, returning from TestMain exits with m.Run's result.
	m.Run()
}
  81. func TestIntegrationSSH(t *testing.T) {
  82. debugTest.Store(true)
  83. t.Cleanup(func() {
  84. debugTest.Store(false)
  85. })
  86. homeDir := "/home/testuser"
  87. if runtime.GOOS == "darwin" {
  88. homeDir = "/Users/testuser"
  89. }
  90. tests := []struct {
  91. cmd string
  92. want []string
  93. forceV1Behavior bool
  94. skip bool
  95. }{
  96. {
  97. cmd: "id",
  98. want: []string{"testuser", "groupone", "grouptwo"},
  99. forceV1Behavior: false,
  100. },
  101. {
  102. cmd: "id",
  103. want: []string{"testuser", "groupone", "grouptwo"},
  104. forceV1Behavior: true,
  105. },
  106. {
  107. cmd: "pwd",
  108. want: []string{homeDir},
  109. skip: !fallbackToSUAvailable(),
  110. forceV1Behavior: false,
  111. },
  112. {
  113. cmd: "echo 'hello'",
  114. want: []string{"hello"},
  115. skip: !fallbackToSUAvailable(),
  116. forceV1Behavior: false,
  117. },
  118. }
  119. for _, test := range tests {
  120. if test.skip {
  121. continue
  122. }
  123. // run every test both without and with a shell
  124. for _, shell := range []bool{false, true} {
  125. shellQualifier := "no_shell"
  126. if shell {
  127. shellQualifier = "shell"
  128. }
  129. versionQualifier := "v2"
  130. if test.forceV1Behavior {
  131. versionQualifier = "v1"
  132. }
  133. t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) {
  134. s := testSession(t, test.forceV1Behavior)
  135. if shell {
  136. err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{
  137. ssh.ECHO: 1,
  138. ssh.TTY_OP_ISPEED: 14400,
  139. ssh.TTY_OP_OSPEED: 14400,
  140. })
  141. if err != nil {
  142. t.Fatalf("unable to request PTY: %s", err)
  143. }
  144. err = s.Shell()
  145. if err != nil {
  146. t.Fatalf("unable to request shell: %s", err)
  147. }
  148. // Read the shell prompt
  149. s.read()
  150. }
  151. got := s.run(t, test.cmd, shell)
  152. for _, want := range test.want {
  153. if !strings.Contains(got, want) {
  154. t.Errorf("%q does not contain %q", got, want)
  155. }
  156. }
  157. })
  158. }
  159. }
  160. }
  161. func TestIntegrationSFTP(t *testing.T) {
  162. debugTest.Store(true)
  163. t.Cleanup(func() {
  164. debugTest.Store(false)
  165. })
  166. for _, forceV1Behavior := range []bool{false, true} {
  167. name := "v2"
  168. if forceV1Behavior {
  169. name = "v1"
  170. }
  171. t.Run(name, func(t *testing.T) {
  172. filePath := "/home/testuser/sftptest.dat"
  173. if forceV1Behavior || !fallbackToSUAvailable() {
  174. filePath = "/tmp/sftptest.dat"
  175. }
  176. wantText := "hello world"
  177. cl := testClient(t, forceV1Behavior)
  178. scl, err := sftp.NewClient(cl)
  179. if err != nil {
  180. t.Fatalf("can't get sftp client: %s", err)
  181. }
  182. file, err := scl.Create(filePath)
  183. if err != nil {
  184. t.Fatalf("can't create file: %s", err)
  185. }
  186. _, err = file.Write([]byte(wantText))
  187. if err != nil {
  188. t.Fatalf("can't write to file: %s", err)
  189. }
  190. err = file.Close()
  191. if err != nil {
  192. t.Fatalf("can't close file: %s", err)
  193. }
  194. file, err = scl.OpenFile(filePath, os.O_RDONLY)
  195. if err != nil {
  196. t.Fatalf("can't open file: %s", err)
  197. }
  198. defer file.Close()
  199. gotText, err := io.ReadAll(file)
  200. if err != nil {
  201. t.Fatalf("can't read file: %s", err)
  202. }
  203. if diff := cmp.Diff(string(gotText), wantText); diff != "" {
  204. t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
  205. }
  206. s := testSessionFor(t, cl)
  207. got := s.run(t, "ls -l "+filePath, false)
  208. if !strings.Contains(got, "testuser") {
  209. t.Fatalf("unexpected file owner user: %s", got)
  210. } else if !strings.Contains(got, "testuser") {
  211. t.Fatalf("unexpected file owner group: %s", got)
  212. }
  213. })
  214. }
  215. }
  216. func TestIntegrationSCP(t *testing.T) {
  217. debugTest.Store(true)
  218. t.Cleanup(func() {
  219. debugTest.Store(false)
  220. })
  221. for _, forceV1Behavior := range []bool{false, true} {
  222. name := "v2"
  223. if forceV1Behavior {
  224. name = "v1"
  225. }
  226. t.Run(name, func(t *testing.T) {
  227. filePath := "/home/testuser/scptest.dat"
  228. if !fallbackToSUAvailable() {
  229. filePath = "/tmp/scptest.dat"
  230. }
  231. wantText := "hello world"
  232. cl := testClient(t, forceV1Behavior)
  233. scl, err := scp.NewClientBySSH(cl)
  234. if err != nil {
  235. t.Fatalf("can't get sftp client: %s", err)
  236. }
  237. err = scl.Copy(context.Background(), strings.NewReader(wantText), filePath, "0644", int64(len(wantText)))
  238. if err != nil {
  239. t.Fatalf("can't create file: %s", err)
  240. }
  241. outfile, err := os.CreateTemp("", "")
  242. if err != nil {
  243. t.Fatalf("can't create temp file: %s", err)
  244. }
  245. err = scl.CopyFromRemote(context.Background(), outfile, filePath)
  246. if err != nil {
  247. t.Fatalf("can't copy file from remote: %s", err)
  248. }
  249. outfile.Close()
  250. gotText, err := os.ReadFile(outfile.Name())
  251. if err != nil {
  252. t.Fatalf("can't read file: %s", err)
  253. }
  254. if diff := cmp.Diff(string(gotText), wantText); diff != "" {
  255. t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
  256. }
  257. s := testSessionFor(t, cl)
  258. got := s.run(t, "ls -l "+filePath, false)
  259. if !strings.Contains(got, "testuser") {
  260. t.Fatalf("unexpected file owner user: %s", got)
  261. } else if !strings.Contains(got, "testuser") {
  262. t.Fatalf("unexpected file owner group: %s", got)
  263. }
  264. })
  265. }
  266. }
  // fallbackToSUAvailable reports whether the prerequisites for the su fallback
// are present on this host: the OS must be Linux, and both su and login must
// be found on PATH. login is required because some operating systems (e.g.
// Fedora) seem to need it present for su to work.
func fallbackToSUAvailable() bool {
	if runtime.GOOS != "linux" {
		return false
	}
	for _, bin := range []string{"su", "login"} {
		if _, err := exec.LookPath(bin); err != nil {
			return false
		}
	}
	return true
}
  280. type session struct {
  281. *ssh.Session
  282. stdin io.WriteCloser
  283. stdout io.ReadCloser
  284. stderr io.ReadCloser
  285. }
  286. func (s *session) run(t *testing.T, cmdString string, shell bool) string {
  287. t.Helper()
  288. if shell {
  289. _, err := s.stdin.Write([]byte(fmt.Sprintf("%s\n", cmdString)))
  290. if err != nil {
  291. t.Fatalf("unable to send command to shell: %s", err)
  292. }
  293. } else {
  294. err := s.Start(cmdString)
  295. if err != nil {
  296. t.Fatalf("unable to start command: %s", err)
  297. }
  298. }
  299. return s.read()
  300. }
  301. func (s *session) read() string {
  302. ch := make(chan []byte)
  303. go func() {
  304. for {
  305. b := make([]byte, 1)
  306. n, err := s.stdout.Read(b)
  307. if n > 0 {
  308. ch <- b
  309. }
  310. if err == io.EOF {
  311. return
  312. }
  313. }
  314. }()
  315. // Read first byte in blocking fashion.
  316. _got := <-ch
  317. // Read subsequent bytes in non-blocking fashion.
  318. readLoop:
  319. for {
  320. select {
  321. case b := <-ch:
  322. _got = append(_got, b...)
  323. case <-time.After(1 * time.Second):
  324. break readLoop
  325. }
  326. }
  327. return string(_got)
  328. }
  329. func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
  330. t.Helper()
  331. username := "testuser"
  332. srv := &server{
  333. lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior},
  334. logf: log.Printf,
  335. tailscaledPath: os.Getenv("TAILSCALED_PATH"),
  336. timeNow: time.Now,
  337. }
  338. l, err := net.Listen("tcp", "127.0.0.1:0")
  339. if err != nil {
  340. t.Fatal(err)
  341. }
  342. t.Cleanup(func() { l.Close() })
  343. go func() {
  344. conn, err := l.Accept()
  345. if err == nil {
  346. go srv.HandleSSHConn(&addressFakingConn{conn})
  347. }
  348. }()
  349. cl, err := ssh.Dial("tcp", l.Addr().String(), &ssh.ClientConfig{
  350. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  351. })
  352. if err != nil {
  353. log.Fatal(err)
  354. }
  355. t.Cleanup(func() { cl.Close() })
  356. return cl
  357. }
  358. func testSession(t *testing.T, forceV1Behavior bool) *session {
  359. cl := testClient(t, forceV1Behavior)
  360. return testSessionFor(t, cl)
  361. }
  362. func testSessionFor(t *testing.T, cl *ssh.Client) *session {
  363. s, err := cl.NewSession()
  364. if err != nil {
  365. log.Fatal(err)
  366. }
  367. t.Cleanup(func() { s.Close() })
  368. stdinReader, stdinWriter := io.Pipe()
  369. stdoutReader, stdoutWriter := io.Pipe()
  370. stderrReader, stderrWriter := io.Pipe()
  371. s.Stdin = stdinReader
  372. s.Stdout = io.MultiWriter(stdoutWriter, os.Stdout)
  373. s.Stderr = io.MultiWriter(stderrWriter, os.Stderr)
  374. return &session{
  375. Session: s,
  376. stdin: stdinWriter,
  377. stdout: stdoutReader,
  378. stderr: stderrReader,
  379. }
  380. }
  // testBackend is the ipnLocalBackend used by these tests. It serves a fixed
// SSH policy built around localUser and toggles SSH behavior version via
// forceV1Behavior.
type testBackend struct {
	// localUser is the local account that SSH sessions run as.
	localUser string
	// forceV1Behavior, when set, makes NetMap advertise
	// tailcfg.NodeAttrSSHBehaviorV1.
	forceV1Behavior bool
}
  386. func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
  387. var result []gossh.Signer
  388. for _, typ := range []string{"ed25519", "ecdsa", "rsa"} {
  389. var priv any
  390. var err error
  391. switch typ {
  392. case "ed25519":
  393. _, priv, err = ed25519.GenerateKey(rand.Reader)
  394. case "ecdsa":
  395. curve := elliptic.P256()
  396. priv, err = ecdsa.GenerateKey(curve, rand.Reader)
  397. case "rsa":
  398. const keySize = 2048
  399. priv, err = rsa.GenerateKey(rand.Reader, keySize)
  400. }
  401. if err != nil {
  402. return nil, err
  403. }
  404. mk, err := x509.MarshalPKCS8PrivateKey(priv)
  405. if err != nil {
  406. return nil, err
  407. }
  408. hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
  409. signer, err := gossh.ParsePrivateKey(hostKey)
  410. if err != nil {
  411. return nil, err
  412. }
  413. result = append(result, signer)
  414. }
  415. return result, nil
  416. }
  417. func (tb *testBackend) ShouldRunSSH() bool {
  418. return true
  419. }
  420. func (tb *testBackend) NetMap() *netmap.NetworkMap {
  421. capMap := make(set.Set[tailcfg.NodeCapability])
  422. if tb.forceV1Behavior {
  423. capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{}
  424. }
  425. return &netmap.NetworkMap{
  426. SSHPolicy: &tailcfg.SSHPolicy{
  427. Rules: []*tailcfg.SSHRule{
  428. {
  429. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  430. Action: &tailcfg.SSHAction{Accept: true},
  431. SSHUsers: map[string]string{"*": tb.localUser},
  432. },
  433. },
  434. },
  435. AllCaps: capMap,
  436. }
  437. }
  438. func (tb *testBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
  439. return (&tailcfg.Node{}).View(), tailcfg.UserProfile{
  440. LoginName: tb.localUser + "@example.com",
  441. }, true
  442. }
  443. func (tb *testBackend) DoNoiseRequest(req *http.Request) (*http.Response, error) {
  444. return nil, nil
  445. }
  446. func (tb *testBackend) Dialer() *tsdial.Dialer {
  447. return nil
  448. }
  449. func (tb *testBackend) TailscaleVarRoot() string {
  450. return ""
  451. }
  452. func (tb *testBackend) NodeKey() key.NodePublic {
  453. return key.NodePublic{}
  454. }
  // addressFakingConn wraps a net.Conn and reports fixed endpoint addresses in
// place of the real (loopback) ones: local 100.100.100.101:22 and remote
// 100.100.100.102:10002. Presumably this lets the SSH server treat test
// connections as if they arrived over Tailscale — TODO confirm against the
// server's address handling. All other methods come from the embedded Conn.
type addressFakingConn struct {
	net.Conn
}

// LocalAddr returns the fixed fake local address 100.100.100.101:22.
func (conn *addressFakingConn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(100, 100, 100, 101), Port: 22}
}

// RemoteAddr returns the fixed fake remote address 100.100.100.102:10002.
func (conn *addressFakingConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(100, 100, 100, 102), Port: 10002}
}