tailssh_test.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. // Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build linux || darwin
  5. // +build linux darwin
  6. package tailssh
  7. import (
  8. "bytes"
  9. "errors"
  10. "fmt"
  11. "net"
  12. "os"
  13. "os/exec"
  14. "os/user"
  15. "strings"
  16. "testing"
  17. "time"
  18. "github.com/tailscale/ssh"
  19. "inet.af/netaddr"
  20. "tailscale.com/ipn/ipnlocal"
  21. "tailscale.com/ipn/store/mem"
  22. "tailscale.com/net/tsdial"
  23. "tailscale.com/tailcfg"
  24. "tailscale.com/types/logger"
  25. "tailscale.com/util/lineread"
  26. "tailscale.com/wgengine"
  27. )
  28. func TestMatchRule(t *testing.T) {
  29. someAction := new(tailcfg.SSHAction)
  30. tests := []struct {
  31. name string
  32. rule *tailcfg.SSHRule
  33. ci *sshConnInfo
  34. wantErr error
  35. wantUser string
  36. }{
  37. {
  38. name: "nil-rule",
  39. rule: nil,
  40. wantErr: errNilRule,
  41. },
  42. {
  43. name: "nil-action",
  44. rule: &tailcfg.SSHRule{},
  45. wantErr: errNilAction,
  46. },
  47. {
  48. name: "expired",
  49. rule: &tailcfg.SSHRule{
  50. Action: someAction,
  51. RuleExpires: timePtr(time.Unix(100, 0)),
  52. },
  53. ci: &sshConnInfo{now: time.Unix(200, 0)},
  54. wantErr: errRuleExpired,
  55. },
  56. {
  57. name: "no-principal",
  58. rule: &tailcfg.SSHRule{
  59. Action: someAction,
  60. },
  61. wantErr: errPrincipalMatch,
  62. },
  63. {
  64. name: "no-user-match",
  65. rule: &tailcfg.SSHRule{
  66. Action: someAction,
  67. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  68. },
  69. ci: &sshConnInfo{sshUser: "alice"},
  70. wantErr: errUserMatch,
  71. },
  72. {
  73. name: "ok-wildcard",
  74. rule: &tailcfg.SSHRule{
  75. Action: someAction,
  76. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  77. SSHUsers: map[string]string{
  78. "*": "ubuntu",
  79. },
  80. },
  81. ci: &sshConnInfo{sshUser: "alice"},
  82. wantUser: "ubuntu",
  83. },
  84. {
  85. name: "ok-wildcard-and-nil-principal",
  86. rule: &tailcfg.SSHRule{
  87. Action: someAction,
  88. Principals: []*tailcfg.SSHPrincipal{
  89. nil, // don't crash on this
  90. {Any: true},
  91. },
  92. SSHUsers: map[string]string{
  93. "*": "ubuntu",
  94. },
  95. },
  96. ci: &sshConnInfo{sshUser: "alice"},
  97. wantUser: "ubuntu",
  98. },
  99. {
  100. name: "ok-exact",
  101. rule: &tailcfg.SSHRule{
  102. Action: someAction,
  103. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  104. SSHUsers: map[string]string{
  105. "*": "ubuntu",
  106. "alice": "thealice",
  107. },
  108. },
  109. ci: &sshConnInfo{sshUser: "alice"},
  110. wantUser: "thealice",
  111. },
  112. {
  113. name: "no-users-for-reject",
  114. rule: &tailcfg.SSHRule{
  115. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  116. Action: &tailcfg.SSHAction{Reject: true},
  117. },
  118. ci: &sshConnInfo{sshUser: "alice"},
  119. },
  120. {
  121. name: "match-principal-node-ip",
  122. rule: &tailcfg.SSHRule{
  123. Action: someAction,
  124. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  125. SSHUsers: map[string]string{"*": "ubuntu"},
  126. },
  127. ci: &sshConnInfo{src: netaddr.MustParseIPPort("1.2.3.4:30343")},
  128. wantUser: "ubuntu",
  129. },
  130. {
  131. name: "match-principal-node-id",
  132. rule: &tailcfg.SSHRule{
  133. Action: someAction,
  134. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  135. SSHUsers: map[string]string{"*": "ubuntu"},
  136. },
  137. ci: &sshConnInfo{node: &tailcfg.Node{StableID: "some-node-ID"}},
  138. wantUser: "ubuntu",
  139. },
  140. {
  141. name: "match-principal-userlogin",
  142. rule: &tailcfg.SSHRule{
  143. Action: someAction,
  144. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  145. SSHUsers: map[string]string{"*": "ubuntu"},
  146. },
  147. ci: &sshConnInfo{uprof: &tailcfg.UserProfile{LoginName: "[email protected]"}},
  148. wantUser: "ubuntu",
  149. },
  150. }
  151. for _, tt := range tests {
  152. t.Run(tt.name, func(t *testing.T) {
  153. got, gotUser, err := matchRule(tt.rule, tt.ci)
  154. if err != tt.wantErr {
  155. t.Errorf("err = %v; want %v", err, tt.wantErr)
  156. }
  157. if gotUser != tt.wantUser {
  158. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  159. }
  160. if err == nil && got == nil {
  161. t.Errorf("expected non-nil action on success")
  162. }
  163. })
  164. }
  165. }
  166. func timePtr(t time.Time) *time.Time { return &t }
  167. func TestSSH(t *testing.T) {
  168. var logf logger.Logf = t.Logf
  169. eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
  170. if err != nil {
  171. t.Fatal(err)
  172. }
  173. lb, err := ipnlocal.NewLocalBackend(logf, "",
  174. new(mem.Store),
  175. new(tsdial.Dialer),
  176. eng, 0)
  177. if err != nil {
  178. t.Fatal(err)
  179. }
  180. defer lb.Shutdown()
  181. dir := t.TempDir()
  182. lb.SetVarRoot(dir)
  183. srv := &server{
  184. lb: lb,
  185. logf: logf,
  186. }
  187. ss, err := srv.newSSHServer()
  188. if err != nil {
  189. t.Fatal(err)
  190. }
  191. u, err := user.Current()
  192. if err != nil {
  193. t.Fatal(err)
  194. }
  195. ci := &sshConnInfo{
  196. sshUser: "test",
  197. src: netaddr.MustParseIPPort("1.2.3.4:32342"),
  198. dst: netaddr.MustParseIPPort("1.2.3.5:22"),
  199. node: &tailcfg.Node{},
  200. uprof: &tailcfg.UserProfile{},
  201. }
  202. ss.Handler = func(s ssh.Session) {
  203. ss := srv.newSSHSession(s, ci, u, &tailcfg.SSHAction{Accept: true})
  204. ss.run()
  205. }
  206. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  207. if err != nil {
  208. t.Fatal(err)
  209. }
  210. defer ln.Close()
  211. port := ln.Addr().(*net.TCPAddr).Port
  212. go func() {
  213. for {
  214. c, err := ln.Accept()
  215. if err != nil {
  216. if !errors.Is(err, net.ErrClosed) {
  217. t.Errorf("Accept: %v", err)
  218. }
  219. return
  220. }
  221. go ss.HandleConn(c)
  222. }
  223. }()
  224. execSSH := func(args ...string) *exec.Cmd {
  225. cmd := exec.Command("ssh",
  226. "-p", fmt.Sprint(port),
  227. "-o", "StrictHostKeyChecking=no",
  228. "[email protected]")
  229. cmd.Args = append(cmd.Args, args...)
  230. return cmd
  231. }
  232. t.Run("env", func(t *testing.T) {
  233. cmd := execSSH("LANG=foo env")
  234. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  235. got, err := cmd.CombinedOutput()
  236. if err != nil {
  237. t.Fatal(err)
  238. }
  239. m := parseEnv(got)
  240. if got := m["USER"]; got == "" || got != u.Username {
  241. if u.Username == "runner" {
  242. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  243. }
  244. t.Errorf("USER = %q; want %q", got, u.Username)
  245. }
  246. if got := m["HOME"]; got == "" || got != u.HomeDir {
  247. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  248. }
  249. if got := m["PWD"]; got == "" || got != u.HomeDir {
  250. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  251. }
  252. if got := m["SHELL"]; got == "" {
  253. t.Errorf("no SHELL")
  254. }
  255. if got, want := m["LANG"], "foo"; got != want {
  256. t.Errorf("LANG = %q; want %q", got, want)
  257. }
  258. if got := m["LOCAL_ENV"]; got != "" {
  259. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  260. }
  261. t.Logf("got: %+v", m)
  262. })
  263. t.Run("stdout_stderr", func(t *testing.T) {
  264. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  265. var outBuf, errBuf bytes.Buffer
  266. cmd.Stdout = &outBuf
  267. cmd.Stderr = &errBuf
  268. if err := cmd.Run(); err != nil {
  269. t.Fatal(err)
  270. }
  271. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  272. // TODO: figure out why these aren't right. should be
  273. // "foo\n" and "bar\n", not "\n" and "bar\n".
  274. })
  275. t.Run("stdin", func(t *testing.T) {
  276. cmd := execSSH("cat")
  277. var outBuf bytes.Buffer
  278. cmd.Stdout = &outBuf
  279. const str = "foo\nbar\n"
  280. cmd.Stdin = strings.NewReader(str)
  281. if err := cmd.Run(); err != nil {
  282. t.Fatal(err)
  283. }
  284. if got := outBuf.String(); got != str {
  285. t.Errorf("got %q; want %q", got, str)
  286. }
  287. })
  288. }
  289. func parseEnv(out []byte) map[string]string {
  290. e := map[string]string{}
  291. lineread.Reader(bytes.NewReader(out), func(line []byte) error {
  292. i := bytes.IndexByte(line, '=')
  293. if i == -1 {
  294. return nil
  295. }
  296. e[string(line[:i])] = string(line[i+1:])
  297. return nil
  298. })
  299. return e
  300. }