tailssh_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. // Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build linux || darwin
  5. // +build linux darwin
  6. package tailssh
  7. import (
  8. "bytes"
  9. "crypto/sha256"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "net"
  14. "net/http"
  15. "net/http/httptest"
  16. "os"
  17. "os/exec"
  18. "os/user"
  19. "reflect"
  20. "strings"
  21. "sync/atomic"
  22. "testing"
  23. "time"
  24. "inet.af/netaddr"
  25. "tailscale.com/ipn/ipnlocal"
  26. "tailscale.com/ipn/store/mem"
  27. "tailscale.com/net/tsdial"
  28. "tailscale.com/tailcfg"
  29. "tailscale.com/tempfork/gliderlabs/ssh"
  30. "tailscale.com/tstest"
  31. "tailscale.com/types/logger"
  32. "tailscale.com/util/cibuild"
  33. "tailscale.com/util/lineread"
  34. "tailscale.com/wgengine"
  35. )
  36. func TestMatchRule(t *testing.T) {
  37. someAction := new(tailcfg.SSHAction)
  38. tests := []struct {
  39. name string
  40. rule *tailcfg.SSHRule
  41. ci *sshConnInfo
  42. wantErr error
  43. wantUser string
  44. }{
  45. {
  46. name: "nil-rule",
  47. rule: nil,
  48. wantErr: errNilRule,
  49. },
  50. {
  51. name: "nil-action",
  52. rule: &tailcfg.SSHRule{},
  53. wantErr: errNilAction,
  54. },
  55. {
  56. name: "expired",
  57. rule: &tailcfg.SSHRule{
  58. Action: someAction,
  59. RuleExpires: timePtr(time.Unix(100, 0)),
  60. },
  61. ci: &sshConnInfo{},
  62. wantErr: errRuleExpired,
  63. },
  64. {
  65. name: "no-principal",
  66. rule: &tailcfg.SSHRule{
  67. Action: someAction,
  68. SSHUsers: map[string]string{
  69. "*": "ubuntu",
  70. }},
  71. ci: &sshConnInfo{},
  72. wantErr: errPrincipalMatch,
  73. },
  74. {
  75. name: "no-user-match",
  76. rule: &tailcfg.SSHRule{
  77. Action: someAction,
  78. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  79. },
  80. ci: &sshConnInfo{sshUser: "alice"},
  81. wantErr: errUserMatch,
  82. },
  83. {
  84. name: "ok-wildcard",
  85. rule: &tailcfg.SSHRule{
  86. Action: someAction,
  87. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  88. SSHUsers: map[string]string{
  89. "*": "ubuntu",
  90. },
  91. },
  92. ci: &sshConnInfo{sshUser: "alice"},
  93. wantUser: "ubuntu",
  94. },
  95. {
  96. name: "ok-wildcard-and-nil-principal",
  97. rule: &tailcfg.SSHRule{
  98. Action: someAction,
  99. Principals: []*tailcfg.SSHPrincipal{
  100. nil, // don't crash on this
  101. {Any: true},
  102. },
  103. SSHUsers: map[string]string{
  104. "*": "ubuntu",
  105. },
  106. },
  107. ci: &sshConnInfo{sshUser: "alice"},
  108. wantUser: "ubuntu",
  109. },
  110. {
  111. name: "ok-exact",
  112. rule: &tailcfg.SSHRule{
  113. Action: someAction,
  114. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  115. SSHUsers: map[string]string{
  116. "*": "ubuntu",
  117. "alice": "thealice",
  118. },
  119. },
  120. ci: &sshConnInfo{sshUser: "alice"},
  121. wantUser: "thealice",
  122. },
  123. {
  124. name: "no-users-for-reject",
  125. rule: &tailcfg.SSHRule{
  126. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  127. Action: &tailcfg.SSHAction{Reject: true},
  128. },
  129. ci: &sshConnInfo{sshUser: "alice"},
  130. },
  131. {
  132. name: "match-principal-node-ip",
  133. rule: &tailcfg.SSHRule{
  134. Action: someAction,
  135. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  136. SSHUsers: map[string]string{"*": "ubuntu"},
  137. },
  138. ci: &sshConnInfo{src: netaddr.MustParseIPPort("1.2.3.4:30343")},
  139. wantUser: "ubuntu",
  140. },
  141. {
  142. name: "match-principal-node-id",
  143. rule: &tailcfg.SSHRule{
  144. Action: someAction,
  145. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  146. SSHUsers: map[string]string{"*": "ubuntu"},
  147. },
  148. ci: &sshConnInfo{node: &tailcfg.Node{StableID: "some-node-ID"}},
  149. wantUser: "ubuntu",
  150. },
  151. {
  152. name: "match-principal-userlogin",
  153. rule: &tailcfg.SSHRule{
  154. Action: someAction,
  155. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  156. SSHUsers: map[string]string{"*": "ubuntu"},
  157. },
  158. ci: &sshConnInfo{uprof: &tailcfg.UserProfile{LoginName: "[email protected]"}},
  159. wantUser: "ubuntu",
  160. },
  161. {
  162. name: "ssh-user-equal",
  163. rule: &tailcfg.SSHRule{
  164. Action: someAction,
  165. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  166. SSHUsers: map[string]string{
  167. "*": "=",
  168. },
  169. },
  170. ci: &sshConnInfo{sshUser: "alice"},
  171. wantUser: "alice",
  172. },
  173. }
  174. for _, tt := range tests {
  175. t.Run(tt.name, func(t *testing.T) {
  176. c := &conn{
  177. now: time.Unix(200, 0),
  178. info: tt.ci,
  179. }
  180. got, gotUser, err := c.matchRule(tt.rule, nil)
  181. if err != tt.wantErr {
  182. t.Errorf("err = %v; want %v", err, tt.wantErr)
  183. }
  184. if gotUser != tt.wantUser {
  185. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  186. }
  187. if err == nil && got == nil {
  188. t.Errorf("expected non-nil action on success")
  189. }
  190. })
  191. }
  192. }
  193. func timePtr(t time.Time) *time.Time { return &t }
  194. func TestSSH(t *testing.T) {
  195. var logf logger.Logf = t.Logf
  196. eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
  197. if err != nil {
  198. t.Fatal(err)
  199. }
  200. lb, err := ipnlocal.NewLocalBackend(logf, "",
  201. new(mem.Store),
  202. new(tsdial.Dialer),
  203. eng, 0)
  204. if err != nil {
  205. t.Fatal(err)
  206. }
  207. defer lb.Shutdown()
  208. dir := t.TempDir()
  209. lb.SetVarRoot(dir)
  210. srv := &server{
  211. lb: lb,
  212. logf: logf,
  213. }
  214. sc, err := srv.newConn()
  215. if err != nil {
  216. t.Fatal(err)
  217. }
  218. // Remove the auth checks for the test
  219. sc.insecureSkipTailscaleAuth = true
  220. u, err := user.Current()
  221. if err != nil {
  222. t.Fatal(err)
  223. }
  224. sc.localUser = u
  225. sc.info = &sshConnInfo{
  226. sshUser: "test",
  227. src: netaddr.MustParseIPPort("1.2.3.4:32342"),
  228. dst: netaddr.MustParseIPPort("1.2.3.5:22"),
  229. node: &tailcfg.Node{},
  230. uprof: &tailcfg.UserProfile{},
  231. }
  232. sc.Handler = func(s ssh.Session) {
  233. sc.newSSHSession(s, &tailcfg.SSHAction{Accept: true}).run()
  234. }
  235. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  236. if err != nil {
  237. t.Fatal(err)
  238. }
  239. defer ln.Close()
  240. port := ln.Addr().(*net.TCPAddr).Port
  241. go func() {
  242. for {
  243. c, err := ln.Accept()
  244. if err != nil {
  245. if !errors.Is(err, net.ErrClosed) {
  246. t.Errorf("Accept: %v", err)
  247. }
  248. return
  249. }
  250. go sc.HandleConn(c)
  251. }
  252. }()
  253. execSSH := func(args ...string) *exec.Cmd {
  254. cmd := exec.Command("ssh",
  255. "-F",
  256. "none",
  257. "-v",
  258. "-p", fmt.Sprint(port),
  259. "-o", "StrictHostKeyChecking=no",
  260. "[email protected]")
  261. cmd.Args = append(cmd.Args, args...)
  262. return cmd
  263. }
  264. t.Run("env", func(t *testing.T) {
  265. if cibuild.On() {
  266. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  267. }
  268. cmd := execSSH("LANG=foo env")
  269. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  270. got, err := cmd.CombinedOutput()
  271. if err != nil {
  272. t.Fatal(err, string(got))
  273. }
  274. m := parseEnv(got)
  275. if got := m["USER"]; got == "" || got != u.Username {
  276. t.Errorf("USER = %q; want %q", got, u.Username)
  277. }
  278. if got := m["HOME"]; got == "" || got != u.HomeDir {
  279. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  280. }
  281. if got := m["PWD"]; got == "" || got != u.HomeDir {
  282. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  283. }
  284. if got := m["SHELL"]; got == "" {
  285. t.Errorf("no SHELL")
  286. }
  287. if got, want := m["LANG"], "foo"; got != want {
  288. t.Errorf("LANG = %q; want %q", got, want)
  289. }
  290. if got := m["LOCAL_ENV"]; got != "" {
  291. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  292. }
  293. t.Logf("got: %+v", m)
  294. })
  295. t.Run("stdout_stderr", func(t *testing.T) {
  296. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  297. var outBuf, errBuf bytes.Buffer
  298. cmd.Stdout = &outBuf
  299. cmd.Stderr = &errBuf
  300. if err := cmd.Run(); err != nil {
  301. t.Fatal(err)
  302. }
  303. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  304. // TODO: figure out why these aren't right. should be
  305. // "foo\n" and "bar\n", not "\n" and "bar\n".
  306. })
  307. t.Run("stdin", func(t *testing.T) {
  308. if cibuild.On() {
  309. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  310. }
  311. cmd := execSSH("cat")
  312. var outBuf bytes.Buffer
  313. cmd.Stdout = &outBuf
  314. const str = "foo\nbar\n"
  315. cmd.Stdin = strings.NewReader(str)
  316. if err := cmd.Run(); err != nil {
  317. t.Fatal(err)
  318. }
  319. if got := outBuf.String(); got != str {
  320. t.Errorf("got %q; want %q", got, str)
  321. }
  322. })
  323. }
  324. func parseEnv(out []byte) map[string]string {
  325. e := map[string]string{}
  326. lineread.Reader(bytes.NewReader(out), func(line []byte) error {
  327. i := bytes.IndexByte(line, '=')
  328. if i == -1 {
  329. return nil
  330. }
  331. e[string(line[:i])] = string(line[i+1:])
  332. return nil
  333. })
  334. return e
  335. }
  336. func TestPublicKeyFetching(t *testing.T) {
  337. var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
  338. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  339. atomic.AddInt32((&reqsTotal), 1)
  340. etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
  341. w.Header().Set("Etag", etag)
  342. if v := r.Header.Get("If-None-Match"); v != "" {
  343. if v == etag {
  344. atomic.AddInt32(&reqsIfNoneMatchHit, 1)
  345. w.WriteHeader(304)
  346. return
  347. }
  348. atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
  349. }
  350. io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
  351. }))
  352. ts.StartTLS()
  353. defer ts.Close()
  354. keys := ts.URL
  355. clock := &tstest.Clock{}
  356. srv := &server{
  357. pubKeyHTTPClient: ts.Client(),
  358. timeNow: clock.Now,
  359. }
  360. for i := 0; i < 2; i++ {
  361. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  362. if err != nil {
  363. t.Fatal(err)
  364. }
  365. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  366. t.Errorf("got %q; want %q", got, want)
  367. }
  368. }
  369. if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
  370. t.Errorf("got %d requests; want %d", got, want)
  371. }
  372. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
  373. t.Errorf("got %d etag hits; want %d", got, want)
  374. }
  375. clock.Advance(5 * time.Minute)
  376. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  377. if err != nil {
  378. t.Fatal(err)
  379. }
  380. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  381. t.Errorf("got %q; want %q", got, want)
  382. }
  383. if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
  384. t.Errorf("got %d requests; want %d", got, want)
  385. }
  386. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
  387. t.Errorf("got %d etag hits; want %d", got, want)
  388. }
  389. if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
  390. t.Errorf("got %d etag misses; want %d", got, want)
  391. }
  392. }
  393. func TestExpandPublicKeyURL(t *testing.T) {
  394. c := &conn{
  395. info: &sshConnInfo{
  396. uprof: &tailcfg.UserProfile{
  397. LoginName: "[email protected]",
  398. },
  399. },
  400. }
  401. if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want {
  402. t.Errorf("basic: got %q; want %q", got, want)
  403. }
  404. if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want {
  405. t.Errorf("localpart: got %q; want %q", got, want)
  406. }
  407. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/[email protected]"; got != want {
  408. t.Errorf("email: got %q; want %q", got, want)
  409. }
  410. c.info = new(sshConnInfo)
  411. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
  412. t.Errorf("on empty: got %q; want %q", got, want)
  413. }
  414. }
  415. func TestAcceptEnvPair(t *testing.T) {
  416. tests := []struct {
  417. in string
  418. want bool
  419. }{
  420. {"TERM=x", true},
  421. {"term=x", false},
  422. {"TERM", false},
  423. {"LC_FOO=x", true},
  424. {"LD_PRELOAD=naah", false},
  425. {"TERM=screen-256color", true},
  426. }
  427. for _, tt := range tests {
  428. if got := acceptEnvPair(tt.in); got != tt.want {
  429. t.Errorf("for %q, got %v; want %v", tt.in, got, tt.want)
  430. }
  431. }
  432. }