tailssh_test.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build linux || darwin
  4. package tailssh
  5. import (
  6. "bytes"
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/sha256"
  11. "encoding/json"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "net/http/httptest"
  18. "net/netip"
  19. "os"
  20. "os/exec"
  21. "os/user"
  22. "reflect"
  23. "runtime"
  24. "slices"
  25. "strconv"
  26. "strings"
  27. "sync"
  28. "sync/atomic"
  29. "testing"
  30. "time"
  31. gossh "github.com/tailscale/golang-x-crypto/ssh"
  32. "tailscale.com/ipn/ipnlocal"
  33. "tailscale.com/ipn/store/mem"
  34. "tailscale.com/net/memnet"
  35. "tailscale.com/net/tsdial"
  36. "tailscale.com/sessionrecording"
  37. "tailscale.com/tailcfg"
  38. "tailscale.com/tempfork/gliderlabs/ssh"
  39. "tailscale.com/tsd"
  40. "tailscale.com/tstest"
  41. "tailscale.com/types/key"
  42. "tailscale.com/types/logger"
  43. "tailscale.com/types/logid"
  44. "tailscale.com/types/netmap"
  45. "tailscale.com/types/ptr"
  46. "tailscale.com/util/cibuild"
  47. "tailscale.com/util/lineiter"
  48. "tailscale.com/util/must"
  49. "tailscale.com/version/distro"
  50. "tailscale.com/wgengine"
  51. )
  52. func TestMatchRule(t *testing.T) {
  53. someAction := new(tailcfg.SSHAction)
  54. tests := []struct {
  55. name string
  56. rule *tailcfg.SSHRule
  57. ci *sshConnInfo
  58. wantErr error
  59. wantUser string
  60. wantAcceptEnv []string
  61. }{
  62. {
  63. name: "invalid-conn",
  64. rule: &tailcfg.SSHRule{
  65. Action: someAction,
  66. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  67. SSHUsers: map[string]string{
  68. "*": "ubuntu",
  69. },
  70. },
  71. wantErr: errInvalidConn,
  72. },
  73. {
  74. name: "nil-rule",
  75. ci: &sshConnInfo{},
  76. rule: nil,
  77. wantErr: errNilRule,
  78. },
  79. {
  80. name: "nil-action",
  81. ci: &sshConnInfo{},
  82. rule: &tailcfg.SSHRule{},
  83. wantErr: errNilAction,
  84. },
  85. {
  86. name: "expired",
  87. rule: &tailcfg.SSHRule{
  88. Action: someAction,
  89. RuleExpires: ptr.To(time.Unix(100, 0)),
  90. },
  91. ci: &sshConnInfo{},
  92. wantErr: errRuleExpired,
  93. },
  94. {
  95. name: "no-principal",
  96. rule: &tailcfg.SSHRule{
  97. Action: someAction,
  98. SSHUsers: map[string]string{
  99. "*": "ubuntu",
  100. }},
  101. ci: &sshConnInfo{},
  102. wantErr: errPrincipalMatch,
  103. },
  104. {
  105. name: "no-user-match",
  106. rule: &tailcfg.SSHRule{
  107. Action: someAction,
  108. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  109. },
  110. ci: &sshConnInfo{sshUser: "alice"},
  111. wantErr: errUserMatch,
  112. },
  113. {
  114. name: "ok-wildcard",
  115. rule: &tailcfg.SSHRule{
  116. Action: someAction,
  117. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  118. SSHUsers: map[string]string{
  119. "*": "ubuntu",
  120. },
  121. },
  122. ci: &sshConnInfo{sshUser: "alice"},
  123. wantUser: "ubuntu",
  124. },
  125. {
  126. name: "ok-wildcard-and-nil-principal",
  127. rule: &tailcfg.SSHRule{
  128. Action: someAction,
  129. Principals: []*tailcfg.SSHPrincipal{
  130. nil, // don't crash on this
  131. {Any: true},
  132. },
  133. SSHUsers: map[string]string{
  134. "*": "ubuntu",
  135. },
  136. },
  137. ci: &sshConnInfo{sshUser: "alice"},
  138. wantUser: "ubuntu",
  139. },
  140. {
  141. name: "ok-exact",
  142. rule: &tailcfg.SSHRule{
  143. Action: someAction,
  144. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  145. SSHUsers: map[string]string{
  146. "*": "ubuntu",
  147. "alice": "thealice",
  148. },
  149. },
  150. ci: &sshConnInfo{sshUser: "alice"},
  151. wantUser: "thealice",
  152. },
  153. {
  154. name: "ok-with-accept-env",
  155. rule: &tailcfg.SSHRule{
  156. Action: someAction,
  157. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  158. SSHUsers: map[string]string{
  159. "*": "ubuntu",
  160. "alice": "thealice",
  161. },
  162. AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
  163. },
  164. ci: &sshConnInfo{sshUser: "alice"},
  165. wantUser: "thealice",
  166. wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
  167. },
  168. {
  169. name: "no-users-for-reject",
  170. rule: &tailcfg.SSHRule{
  171. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  172. Action: &tailcfg.SSHAction{Reject: true},
  173. },
  174. ci: &sshConnInfo{sshUser: "alice"},
  175. },
  176. {
  177. name: "match-principal-node-ip",
  178. rule: &tailcfg.SSHRule{
  179. Action: someAction,
  180. Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}},
  181. SSHUsers: map[string]string{"*": "ubuntu"},
  182. },
  183. ci: &sshConnInfo{src: netip.MustParseAddrPort("1.2.3.4:30343")},
  184. wantUser: "ubuntu",
  185. },
  186. {
  187. name: "match-principal-node-id",
  188. rule: &tailcfg.SSHRule{
  189. Action: someAction,
  190. Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}},
  191. SSHUsers: map[string]string{"*": "ubuntu"},
  192. },
  193. ci: &sshConnInfo{node: (&tailcfg.Node{StableID: "some-node-ID"}).View()},
  194. wantUser: "ubuntu",
  195. },
  196. {
  197. name: "match-principal-userlogin",
  198. rule: &tailcfg.SSHRule{
  199. Action: someAction,
  200. Principals: []*tailcfg.SSHPrincipal{{UserLogin: "[email protected]"}},
  201. SSHUsers: map[string]string{"*": "ubuntu"},
  202. },
  203. ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "[email protected]"}},
  204. wantUser: "ubuntu",
  205. },
  206. {
  207. name: "ssh-user-equal",
  208. rule: &tailcfg.SSHRule{
  209. Action: someAction,
  210. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  211. SSHUsers: map[string]string{
  212. "*": "=",
  213. },
  214. },
  215. ci: &sshConnInfo{sshUser: "alice"},
  216. wantUser: "alice",
  217. },
  218. }
  219. for _, tt := range tests {
  220. t.Run(tt.name, func(t *testing.T) {
  221. c := &conn{
  222. info: tt.ci,
  223. srv: &server{logf: t.Logf},
  224. }
  225. got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil)
  226. if err != tt.wantErr {
  227. t.Errorf("err = %v; want %v", err, tt.wantErr)
  228. }
  229. if gotUser != tt.wantUser {
  230. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  231. }
  232. if err == nil && got == nil {
  233. t.Errorf("expected non-nil action on success")
  234. }
  235. if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
  236. t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
  237. }
  238. })
  239. }
  240. }
  241. func TestEvalSSHPolicy(t *testing.T) {
  242. someAction := new(tailcfg.SSHAction)
  243. tests := []struct {
  244. name string
  245. policy *tailcfg.SSHPolicy
  246. ci *sshConnInfo
  247. wantMatch bool
  248. wantUser string
  249. wantAcceptEnv []string
  250. }{
  251. {
  252. name: "multiple-matches-picks-first-match",
  253. policy: &tailcfg.SSHPolicy{
  254. Rules: []*tailcfg.SSHRule{
  255. {
  256. Action: someAction,
  257. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  258. SSHUsers: map[string]string{
  259. "other": "other1",
  260. },
  261. },
  262. {
  263. Action: someAction,
  264. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  265. SSHUsers: map[string]string{
  266. "*": "ubuntu",
  267. "alice": "thealice",
  268. },
  269. AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
  270. },
  271. {
  272. Action: someAction,
  273. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  274. SSHUsers: map[string]string{
  275. "other2": "other3",
  276. },
  277. },
  278. {
  279. Action: someAction,
  280. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  281. SSHUsers: map[string]string{
  282. "*": "ubuntu",
  283. "alice": "thealice",
  284. "mark": "markthe",
  285. },
  286. AcceptEnv: []string{"*"},
  287. },
  288. },
  289. },
  290. ci: &sshConnInfo{sshUser: "alice"},
  291. wantUser: "thealice",
  292. wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
  293. wantMatch: true,
  294. },
  295. {
  296. name: "no-matches-returns-failure",
  297. policy: &tailcfg.SSHPolicy{
  298. Rules: []*tailcfg.SSHRule{
  299. {
  300. Action: someAction,
  301. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  302. SSHUsers: map[string]string{
  303. "other": "other1",
  304. },
  305. },
  306. {
  307. Action: someAction,
  308. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  309. SSHUsers: map[string]string{
  310. "fedora": "ubuntu",
  311. },
  312. AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
  313. },
  314. {
  315. Action: someAction,
  316. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  317. SSHUsers: map[string]string{
  318. "other2": "other3",
  319. },
  320. },
  321. {
  322. Action: someAction,
  323. Principals: []*tailcfg.SSHPrincipal{{Any: true}},
  324. SSHUsers: map[string]string{
  325. "mark": "markthe",
  326. },
  327. AcceptEnv: []string{"*"},
  328. },
  329. },
  330. },
  331. ci: &sshConnInfo{sshUser: "alice"},
  332. wantUser: "",
  333. wantAcceptEnv: nil,
  334. wantMatch: false,
  335. },
  336. }
  337. for _, tt := range tests {
  338. t.Run(tt.name, func(t *testing.T) {
  339. c := &conn{
  340. info: tt.ci,
  341. srv: &server{logf: t.Logf},
  342. }
  343. got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil)
  344. if match != tt.wantMatch {
  345. t.Errorf("match = %v; want %v", match, tt.wantMatch)
  346. }
  347. if gotUser != tt.wantUser {
  348. t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
  349. }
  350. if tt.wantMatch == true && got == nil {
  351. t.Errorf("expected non-nil action on success")
  352. }
  353. if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
  354. t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
  355. }
  356. })
  357. }
  358. }
// localState implements ipnLocalBackend for testing.
//
// Its behavior is driven entirely by the fields below: ShouldRunSSH reports
// sshEnabled, NetMap exposes matchingRule (when non-nil) as the only rule of
// the SSH policy, and DoNoiseRequest serves entries of serverActions.
type localState struct {
	// sshEnabled is the value returned by ShouldRunSSH.
	sshEnabled bool
	// matchingRule, if non-nil, is the sole rule of NetMap().SSHPolicy;
	// if nil, the netmap carries no SSH policy at all.
	matchingRule *tailcfg.SSHRule
	// serverActions is a map of the action name to the action.
	// It is served for paths like https://unused/ssh-action/<action-name>.
	// The action name is the last part of the action URL.
	serverActions map[string]*tailcfg.SSHAction
}
  368. var (
  369. currentUser = os.Getenv("USER") // Use the current user for the test.
  370. testSigner gossh.Signer
  371. testSignerOnce sync.Once
  372. )
  373. func (ts *localState) Dialer() *tsdial.Dialer {
  374. return &tsdial.Dialer{}
  375. }
  376. func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
  377. testSignerOnce.Do(func() {
  378. _, priv, err := ed25519.GenerateKey(rand.Reader)
  379. if err != nil {
  380. panic(err)
  381. }
  382. s, err := gossh.NewSignerFromSigner(priv)
  383. if err != nil {
  384. panic(err)
  385. }
  386. testSigner = s
  387. })
  388. return []gossh.Signer{testSigner}, nil
  389. }
// ShouldRunSSH reports whether SSH is enabled, as configured by the test via
// the sshEnabled field.
func (ts *localState) ShouldRunSSH() bool {
	return ts.sshEnabled
}
  393. func (ts *localState) NetMap() *netmap.NetworkMap {
  394. var policy *tailcfg.SSHPolicy
  395. if ts.matchingRule != nil {
  396. policy = &tailcfg.SSHPolicy{
  397. Rules: []*tailcfg.SSHRule{
  398. ts.matchingRule,
  399. },
  400. }
  401. }
  402. return &netmap.NetworkMap{
  403. SelfNode: (&tailcfg.Node{
  404. ID: 1,
  405. }).View(),
  406. SSHPolicy: policy,
  407. }
  408. }
  409. func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
  410. if proto != "tcp" {
  411. return tailcfg.NodeView{}, tailcfg.UserProfile{}, false
  412. }
  413. return (&tailcfg.Node{
  414. ID: 2,
  415. StableID: "peer-id",
  416. }).View(), tailcfg.UserProfile{
  417. LoginName: "peer",
  418. }, true
  419. }
  420. func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) {
  421. rec := httptest.NewRecorder()
  422. k, ok := strings.CutPrefix(req.URL.Path, "/ssh-action/")
  423. if !ok {
  424. rec.WriteHeader(http.StatusNotFound)
  425. }
  426. a, ok := ts.serverActions[k]
  427. if !ok {
  428. rec.WriteHeader(http.StatusNotFound)
  429. return rec.Result(), nil
  430. }
  431. rec.WriteHeader(http.StatusOK)
  432. if err := json.NewEncoder(rec).Encode(a); err != nil {
  433. return nil, err
  434. }
  435. return rec.Result(), nil
  436. }
// TailscaleVarRoot reports no state directory: it always returns "".
func (ts *localState) TailscaleVarRoot() string {
	return ""
}
  440. func (ts *localState) NodeKey() key.NodePublic {
  441. return key.NewNode().Public()
  442. }
  443. func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
  444. return &tailcfg.SSHRule{
  445. SSHUsers: map[string]string{
  446. "*": currentUser,
  447. },
  448. Action: action,
  449. Principals: []*tailcfg.SSHPrincipal{
  450. {
  451. Any: true,
  452. },
  453. },
  454. }
  455. }
// TestSSHRecordingCancelsSessionsOnUploadFailure checks OnRecordingFailure
// handling when the single configured recorder misbehaves:
//   - upload-denied: the recorder answers 403 immediately; the client output
//     must end with RejectSessionWithMessage and must not contain the
//     command's output.
//   - upload-fails-after-starting: the recorder accepts some data, then
//     answers 500; the client output must end with
//     TerminateSessionWithMessage and must not contain output produced after
//     the failure.
//
// In both cases the client's command is expected to end with an error.
func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
	}
	// The recording server forwards every request to handler, which each
	// subtest swaps out via tstest.Replace, so one server serves all cases.
	var handler http.HandlerFunc
	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handler(w, r)
	}))
	defer recordingServer.Close()
	s := &server{
		logf: t.Logf,
		lb: &localState{
			sshEnabled: true,
			matchingRule: newSSHRule(
				&tailcfg.SSHAction{
					Accept: true,
					Recorders: []netip.AddrPort{
						netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
					},
					OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
						RejectSessionWithMessage:    "session rejected",
						TerminateSessionWithMessage: "session terminated",
					},
				},
			),
		},
	}
	defer s.Shutdown()
	const sshUser = "alice"
	cfg := &gossh.ClientConfig{
		User:            sshUser,
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	}
	tests := []struct {
		name       string
		handler    func(w http.ResponseWriter, r *http.Request)
		sshCommand string
		// wantClientOutput must be a suffix of the client's combined output.
		wantClientOutput           string
		clientOutputMustNotContain []string
	}{
		{
			name: "upload-denied",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusForbidden)
			},
			sshCommand:                 "echo hello",
			wantClientOutput:           "session rejected\r\n",
			clientOutputMustNotContain: []string{"hello"},
		},
		{
			name: "upload-fails-after-starting",
			handler: func(w http.ResponseWriter, r *http.Request) {
				// Block until the upload has delivered at least one byte, so the
				// session has started; the read's results are deliberately ignored.
				r.Body.Read(make([]byte, 1))
				time.Sleep(100 * time.Millisecond)
				w.WriteHeader(http.StatusInternalServerError)
			},
			// "world" is printed only after the sleep, i.e. after the recorder
			// has failed, so it must never reach the client.
			sshCommand:                 "echo hello && sleep 1 && echo world",
			wantClientOutput:           "\r\n\r\nsession terminated\r\n\r\n",
			clientOutputMustNotContain: []string{"world"},
		},
	}
	src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tstest.Replace(t, &handler, tt.handler)
			// sc is the client side and dc the server side of an in-memory TCP pair.
			sc, dc := memnet.NewTCPConn(src, dst, 1024)
			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				defer wg.Done()
				c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
				if err != nil {
					t.Errorf("client: %v", err)
					return
				}
				client := gossh.NewClient(c, chans, reqs)
				defer client.Close()
				session, err := client.NewSession()
				if err != nil {
					t.Errorf("client: %v", err)
					return
				}
				defer session.Close()
				t.Logf("client established session")
				got, err := session.CombinedOutput(tt.sshCommand)
				// The server is expected to cut the session short, so a nil
				// error means the recording failure was not enforced.
				if err != nil {
					t.Logf("client got: %q: %v", got, err)
				} else {
					t.Errorf("client did not get kicked out: %q", got)
				}
				gotStr := string(got)
				if !strings.HasSuffix(gotStr, tt.wantClientOutput) {
					t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
				}
				for _, x := range tt.clientOutputMustNotContain {
					if strings.Contains(gotStr, x) {
						t.Errorf("client output must not contain %q", x)
					}
				}
			}()
			if err := s.HandleSSHConn(dc); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			wg.Wait()
		})
	}
}
// TestMultipleRecorders configures four recorders and expects the session to
// succeed and be recorded by the last one. The recorders are tried in the
// order listed:
//  1. an address whose listener was already closed (expected to refuse
//     connections),
//  2. a server that always answers 500,
//  3. a server that answers 200 without reading the request body,
//  4. a server that reads the whole body and answers 200.
//
// The command must run normally, and the good recorder must receive a
// request within one second of the SSH connection finishing.
func TestMultipleRecorders(t *testing.T) {
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
	}
	// done is closed by the good recorder once it has consumed an upload.
	// NOTE(review): a second request to this handler would close done twice
	// and panic; the test assumes exactly one upload reaches it.
	done := make(chan struct{})
	recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer close(done)
		io.ReadAll(r.Body)
		w.WriteHeader(http.StatusOK)
	}))
	defer recordingServer.Close()
	// Reserve a port and close it immediately, leaving an address with no
	// listener behind it.
	badRecorder, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}
	badRecorderAddr := badRecorder.Addr().String()
	badRecorder.Close()
	badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(500)
	}))
	defer badRecordingServer500.Close()
	// Answers 200 without consuming the recording.
	badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
	}))
	defer badRecordingServer200.Close()
	s := &server{
		logf: t.Logf,
		lb: &localState{
			sshEnabled: true,
			matchingRule: newSSHRule(
				&tailcfg.SSHAction{
					Accept: true,
					Recorders: []netip.AddrPort{
						netip.MustParseAddrPort(badRecorderAddr),
						netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
						netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
						netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
					},
					OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
						RejectSessionWithMessage:    "session rejected",
						TerminateSessionWithMessage: "session terminated",
					},
				},
			),
		},
	}
	defer s.Shutdown()
	src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
	// sc is the client side and dc the server side of an in-memory TCP pair.
	sc, dc := memnet.NewTCPConn(src, dst, 1024)
	const sshUser = "alice"
	cfg := &gossh.ClientConfig{
		User:            sshUser,
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	}
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
		if err != nil {
			t.Errorf("client: %v", err)
			return
		}
		client := gossh.NewClient(c, chans, reqs)
		defer client.Close()
		session, err := client.NewSession()
		if err != nil {
			t.Errorf("client: %v", err)
			return
		}
		defer session.Close()
		t.Logf("client established session")
		out, err := session.CombinedOutput("echo Ran echo!")
		if err != nil {
			t.Errorf("client: %v", err)
		}
		if string(out) != "Ran echo!\n" {
			t.Errorf("client: unexpected output: %q", out)
		}
	}()
	if err := s.HandleSSHConn(dc); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	wg.Wait()
	select {
	case <-done:
	case <-time.After(1 * time.Second):
		t.Fatal("timed out waiting for recording")
	}
}
  653. // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
  654. // when the client is not interactive (i.e. no PTY).
  655. // It starts a local SSH server and a recording server. The recording server
  656. // records the SSH session and returns it to the test.
  657. // The test then verifies that the recording has a valid CastHeader, it does not
  658. // validate the contents of the recording.
  659. func TestSSHRecordingNonInteractive(t *testing.T) {
  660. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  661. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  662. }
  663. var recording []byte
  664. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  665. recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  666. defer cancel()
  667. var err error
  668. recording, err = io.ReadAll(r.Body)
  669. if err != nil {
  670. t.Error(err)
  671. return
  672. }
  673. }))
  674. defer recordingServer.Close()
  675. s := &server{
  676. logf: logger.Discard,
  677. lb: &localState{
  678. sshEnabled: true,
  679. matchingRule: newSSHRule(
  680. &tailcfg.SSHAction{
  681. Accept: true,
  682. Recorders: []netip.AddrPort{
  683. must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
  684. },
  685. OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
  686. RejectSessionWithMessage: "session rejected",
  687. TerminateSessionWithMessage: "session terminated",
  688. },
  689. },
  690. ),
  691. },
  692. }
  693. defer s.Shutdown()
  694. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  695. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  696. const sshUser = "alice"
  697. cfg := &gossh.ClientConfig{
  698. User: sshUser,
  699. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  700. }
  701. var wg sync.WaitGroup
  702. wg.Add(1)
  703. go func() {
  704. defer wg.Done()
  705. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  706. if err != nil {
  707. t.Errorf("client: %v", err)
  708. return
  709. }
  710. client := gossh.NewClient(c, chans, reqs)
  711. defer client.Close()
  712. session, err := client.NewSession()
  713. if err != nil {
  714. t.Errorf("client: %v", err)
  715. return
  716. }
  717. defer session.Close()
  718. t.Logf("client established session")
  719. _, err = session.CombinedOutput("echo Ran echo!")
  720. if err != nil {
  721. t.Errorf("client: %v", err)
  722. }
  723. }()
  724. if err := s.HandleSSHConn(dc); err != nil {
  725. t.Errorf("unexpected error: %v", err)
  726. }
  727. wg.Wait()
  728. <-ctx.Done() // wait for recording to finish
  729. var ch sessionrecording.CastHeader
  730. if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
  731. t.Fatal(err)
  732. }
  733. if ch.SSHUser != sshUser {
  734. t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
  735. }
  736. if ch.Command != "echo Ran echo!" {
  737. t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
  738. }
  739. }
  740. func TestSSHAuthFlow(t *testing.T) {
  741. if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
  742. t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
  743. }
  744. acceptRule := newSSHRule(&tailcfg.SSHAction{
  745. Accept: true,
  746. Message: "Welcome to Tailscale SSH!",
  747. })
  748. rejectRule := newSSHRule(&tailcfg.SSHAction{
  749. Reject: true,
  750. Message: "Go Away!",
  751. })
  752. tests := []struct {
  753. name string
  754. sshUser string // defaults to alice
  755. state *localState
  756. wantBanners []string
  757. usesPassword bool
  758. authErr bool
  759. }{
  760. {
  761. name: "no-policy",
  762. state: &localState{
  763. sshEnabled: true,
  764. },
  765. authErr: true,
  766. },
  767. {
  768. name: "accept",
  769. state: &localState{
  770. sshEnabled: true,
  771. matchingRule: acceptRule,
  772. },
  773. wantBanners: []string{"Welcome to Tailscale SSH!"},
  774. },
  775. {
  776. name: "reject",
  777. state: &localState{
  778. sshEnabled: true,
  779. matchingRule: rejectRule,
  780. },
  781. wantBanners: []string{"Go Away!"},
  782. authErr: true,
  783. },
  784. {
  785. name: "simple-check",
  786. state: &localState{
  787. sshEnabled: true,
  788. matchingRule: newSSHRule(&tailcfg.SSHAction{
  789. HoldAndDelegate: "https://unused/ssh-action/accept",
  790. }),
  791. serverActions: map[string]*tailcfg.SSHAction{
  792. "accept": acceptRule.Action,
  793. },
  794. },
  795. wantBanners: []string{"Welcome to Tailscale SSH!"},
  796. },
  797. {
  798. name: "multi-check",
  799. state: &localState{
  800. sshEnabled: true,
  801. matchingRule: newSSHRule(&tailcfg.SSHAction{
  802. Message: "First",
  803. HoldAndDelegate: "https://unused/ssh-action/check1",
  804. }),
  805. serverActions: map[string]*tailcfg.SSHAction{
  806. "check1": {
  807. Message: "url-here",
  808. HoldAndDelegate: "https://unused/ssh-action/check2",
  809. },
  810. "check2": acceptRule.Action,
  811. },
  812. },
  813. wantBanners: []string{"First", "url-here", "Welcome to Tailscale SSH!"},
  814. },
  815. {
  816. name: "check-reject",
  817. state: &localState{
  818. sshEnabled: true,
  819. matchingRule: newSSHRule(&tailcfg.SSHAction{
  820. Message: "First",
  821. HoldAndDelegate: "https://unused/ssh-action/reject",
  822. }),
  823. serverActions: map[string]*tailcfg.SSHAction{
  824. "reject": rejectRule.Action,
  825. },
  826. },
  827. wantBanners: []string{"First", "Go Away!"},
  828. authErr: true,
  829. },
  830. {
  831. name: "force-password-auth",
  832. sshUser: "alice+password",
  833. state: &localState{
  834. sshEnabled: true,
  835. matchingRule: acceptRule,
  836. },
  837. usesPassword: true,
  838. wantBanners: []string{"Welcome to Tailscale SSH!"},
  839. },
  840. }
  841. s := &server{
  842. logf: logger.Discard,
  843. }
  844. defer s.Shutdown()
  845. src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
  846. for _, tc := range tests {
  847. t.Run(tc.name, func(t *testing.T) {
  848. sc, dc := memnet.NewTCPConn(src, dst, 1024)
  849. s.lb = tc.state
  850. sshUser := "alice"
  851. if tc.sshUser != "" {
  852. sshUser = tc.sshUser
  853. }
  854. var passwordUsed atomic.Bool
  855. cfg := &gossh.ClientConfig{
  856. User: sshUser,
  857. HostKeyCallback: gossh.InsecureIgnoreHostKey(),
  858. Auth: []gossh.AuthMethod{
  859. gossh.PasswordCallback(func() (secret string, err error) {
  860. if !tc.usesPassword {
  861. t.Error("unexpected use of PasswordCallback")
  862. return "", errors.New("unexpected use of PasswordCallback")
  863. }
  864. passwordUsed.Store(true)
  865. return "any-pass", nil
  866. }),
  867. },
  868. BannerCallback: func(message string) error {
  869. if len(tc.wantBanners) == 0 {
  870. t.Errorf("unexpected banner: %q", message)
  871. } else if message != tc.wantBanners[0] {
  872. t.Errorf("banner = %q; want %q", message, tc.wantBanners[0])
  873. } else {
  874. t.Logf("banner = %q", message)
  875. tc.wantBanners = tc.wantBanners[1:]
  876. }
  877. return nil
  878. },
  879. }
  880. var wg sync.WaitGroup
  881. wg.Add(1)
  882. go func() {
  883. defer wg.Done()
  884. c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
  885. if err != nil {
  886. if !tc.authErr {
  887. t.Errorf("client: %v", err)
  888. }
  889. return
  890. } else if tc.authErr {
  891. c.Close()
  892. t.Errorf("client: expected error, got nil")
  893. return
  894. }
  895. client := gossh.NewClient(c, chans, reqs)
  896. defer client.Close()
  897. session, err := client.NewSession()
  898. if err != nil {
  899. t.Errorf("client: %v", err)
  900. return
  901. }
  902. defer session.Close()
  903. _, err = session.CombinedOutput("echo Ran echo!")
  904. if err != nil {
  905. t.Errorf("client: %v", err)
  906. }
  907. }()
  908. if err := s.HandleSSHConn(dc); err != nil {
  909. t.Errorf("unexpected error: %v", err)
  910. }
  911. wg.Wait()
  912. if len(tc.wantBanners) > 0 {
  913. t.Errorf("missing banners: %v", tc.wantBanners)
  914. }
  915. })
  916. }
  917. }
  918. func TestSSH(t *testing.T) {
  919. var logf logger.Logf = t.Logf
  920. sys := &tsd.System{}
  921. eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry())
  922. if err != nil {
  923. t.Fatal(err)
  924. }
  925. sys.Set(eng)
  926. sys.Set(new(mem.Store))
  927. lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
  928. if err != nil {
  929. t.Fatal(err)
  930. }
  931. defer lb.Shutdown()
  932. dir := t.TempDir()
  933. lb.SetVarRoot(dir)
  934. srv := &server{
  935. lb: lb,
  936. logf: logf,
  937. }
  938. sc, err := srv.newConn()
  939. if err != nil {
  940. t.Fatal(err)
  941. }
  942. // Remove the auth checks for the test
  943. sc.insecureSkipTailscaleAuth = true
  944. u, err := user.Current()
  945. if err != nil {
  946. t.Fatal(err)
  947. }
  948. um, err := userLookup(u.Username)
  949. if err != nil {
  950. t.Fatal(err)
  951. }
  952. sc.localUser = um
  953. sc.info = &sshConnInfo{
  954. sshUser: "test",
  955. src: netip.MustParseAddrPort("1.2.3.4:32342"),
  956. dst: netip.MustParseAddrPort("1.2.3.5:22"),
  957. node: (&tailcfg.Node{}).View(),
  958. uprof: tailcfg.UserProfile{},
  959. }
  960. sc.action0 = &tailcfg.SSHAction{Accept: true}
  961. sc.finalAction = sc.action0
  962. sc.Handler = func(s ssh.Session) {
  963. sc.newSSHSession(s).run()
  964. }
  965. ln, err := net.Listen("tcp4", "127.0.0.1:0")
  966. if err != nil {
  967. t.Fatal(err)
  968. }
  969. defer ln.Close()
  970. port := ln.Addr().(*net.TCPAddr).Port
  971. go func() {
  972. for {
  973. c, err := ln.Accept()
  974. if err != nil {
  975. if !errors.Is(err, net.ErrClosed) {
  976. t.Errorf("Accept: %v", err)
  977. }
  978. return
  979. }
  980. go sc.HandleConn(c)
  981. }
  982. }()
  983. execSSH := func(args ...string) *exec.Cmd {
  984. cmd := exec.Command("ssh",
  985. "-F",
  986. "none",
  987. "-v",
  988. "-p", fmt.Sprint(port),
  989. "-o", "StrictHostKeyChecking=no",
  990. "[email protected]")
  991. cmd.Args = append(cmd.Args, args...)
  992. return cmd
  993. }
  994. t.Run("env", func(t *testing.T) {
  995. if cibuild.On() {
  996. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  997. }
  998. cmd := execSSH("LANG=foo env")
  999. cmd.Env = append(os.Environ(), "LOCAL_ENV=bar")
  1000. got, err := cmd.CombinedOutput()
  1001. if err != nil {
  1002. t.Fatal(err, string(got))
  1003. }
  1004. m := parseEnv(got)
  1005. if got := m["USER"]; got == "" || got != u.Username {
  1006. t.Errorf("USER = %q; want %q", got, u.Username)
  1007. }
  1008. if got := m["HOME"]; got == "" || got != u.HomeDir {
  1009. t.Errorf("HOME = %q; want %q", got, u.HomeDir)
  1010. }
  1011. if got := m["PWD"]; got == "" || got != u.HomeDir {
  1012. t.Errorf("PWD = %q; want %q", got, u.HomeDir)
  1013. }
  1014. if got := m["SHELL"]; got == "" {
  1015. t.Errorf("no SHELL")
  1016. }
  1017. if got, want := m["LANG"], "foo"; got != want {
  1018. t.Errorf("LANG = %q; want %q", got, want)
  1019. }
  1020. if got := m["LOCAL_ENV"]; got != "" {
  1021. t.Errorf("LOCAL_ENV leaked over ssh: %v", got)
  1022. }
  1023. t.Logf("got: %+v", m)
  1024. })
  1025. t.Run("stdout_stderr", func(t *testing.T) {
  1026. cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
  1027. var outBuf, errBuf bytes.Buffer
  1028. cmd.Stdout = &outBuf
  1029. cmd.Stderr = &errBuf
  1030. if err := cmd.Run(); err != nil {
  1031. t.Fatal(err)
  1032. }
  1033. t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
  1034. // TODO: figure out why these aren't right. should be
  1035. // "foo\n" and "bar\n", not "\n" and "bar\n".
  1036. })
  1037. t.Run("large_file", func(t *testing.T) {
  1038. const wantSize = 1e6
  1039. var outBuf bytes.Buffer
  1040. cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
  1041. cmd.Stdout = &outBuf
  1042. if err := cmd.Run(); err != nil {
  1043. t.Fatal(err)
  1044. }
  1045. if gotSize := outBuf.Len(); gotSize != wantSize {
  1046. t.Fatalf("got %d, want %d", gotSize, int(wantSize))
  1047. }
  1048. })
  1049. t.Run("stdin", func(t *testing.T) {
  1050. if cibuild.On() {
  1051. t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
  1052. }
  1053. cmd := execSSH("cat")
  1054. var outBuf bytes.Buffer
  1055. cmd.Stdout = &outBuf
  1056. const str = "foo\nbar\n"
  1057. cmd.Stdin = strings.NewReader(str)
  1058. if err := cmd.Run(); err != nil {
  1059. t.Fatal(err)
  1060. }
  1061. if got := outBuf.String(); got != str {
  1062. t.Errorf("got %q; want %q", got, str)
  1063. }
  1064. })
  1065. }
  1066. func parseEnv(out []byte) map[string]string {
  1067. e := map[string]string{}
  1068. for line := range lineiter.Bytes(out) {
  1069. if i := bytes.IndexByte(line, '='); i != -1 {
  1070. e[string(line[:i])] = string(line[i+1:])
  1071. }
  1072. }
  1073. return e
  1074. }
  1075. func TestPublicKeyFetching(t *testing.T) {
  1076. var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
  1077. ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1078. atomic.AddInt32((&reqsTotal), 1)
  1079. etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
  1080. w.Header().Set("Etag", etag)
  1081. if v := r.Header.Get("If-None-Match"); v != "" {
  1082. if v == etag {
  1083. atomic.AddInt32(&reqsIfNoneMatchHit, 1)
  1084. w.WriteHeader(304)
  1085. return
  1086. }
  1087. atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
  1088. }
  1089. io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
  1090. }))
  1091. ts.StartTLS()
  1092. defer ts.Close()
  1093. keys := ts.URL
  1094. clock := &tstest.Clock{}
  1095. srv := &server{
  1096. pubKeyHTTPClient: ts.Client(),
  1097. timeNow: clock.Now,
  1098. }
  1099. for range 2 {
  1100. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  1101. if err != nil {
  1102. t.Fatal(err)
  1103. }
  1104. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  1105. t.Errorf("got %q; want %q", got, want)
  1106. }
  1107. }
  1108. if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
  1109. t.Errorf("got %d requests; want %d", got, want)
  1110. }
  1111. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
  1112. t.Errorf("got %d etag hits; want %d", got, want)
  1113. }
  1114. clock.Advance(5 * time.Minute)
  1115. got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
  1116. if err != nil {
  1117. t.Fatal(err)
  1118. }
  1119. if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
  1120. t.Errorf("got %q; want %q", got, want)
  1121. }
  1122. if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
  1123. t.Errorf("got %d requests; want %d", got, want)
  1124. }
  1125. if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
  1126. t.Errorf("got %d etag hits; want %d", got, want)
  1127. }
  1128. if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
  1129. t.Errorf("got %d etag misses; want %d", got, want)
  1130. }
  1131. }
  1132. func TestExpandPublicKeyURL(t *testing.T) {
  1133. c := &conn{
  1134. info: &sshConnInfo{
  1135. uprof: tailcfg.UserProfile{
  1136. LoginName: "[email protected]",
  1137. },
  1138. },
  1139. }
  1140. if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want {
  1141. t.Errorf("basic: got %q; want %q", got, want)
  1142. }
  1143. if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want {
  1144. t.Errorf("localpart: got %q; want %q", got, want)
  1145. }
  1146. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/[email protected]"; got != want {
  1147. t.Errorf("email: got %q; want %q", got, want)
  1148. }
  1149. c.info = new(sshConnInfo)
  1150. if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want {
  1151. t.Errorf("on empty: got %q; want %q", got, want)
  1152. }
  1153. }
  1154. func TestAcceptEnvPair(t *testing.T) {
  1155. tests := []struct {
  1156. in string
  1157. want bool
  1158. }{
  1159. {"TERM=x", true},
  1160. {"term=x", false},
  1161. {"TERM", false},
  1162. {"LC_FOO=x", true},
  1163. {"LD_PRELOAD=naah", false},
  1164. {"TERM=screen-256color", true},
  1165. }
  1166. for _, tt := range tests {
  1167. if got := acceptEnvPair(tt.in); got != tt.want {
  1168. t.Errorf("for %q, got %v; want %v", tt.in, got, tt.want)
  1169. }
  1170. }
  1171. }
  1172. func TestPathFromPAMEnvLine(t *testing.T) {
  1173. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1174. tests := []struct {
  1175. line string
  1176. u *user.User
  1177. want string
  1178. }{
  1179. {"", u, ""},
  1180. {`PATH DEFAULT="/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"`,
  1181. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1182. {`PATH DEFAULT="@{SOMETHING_ELSE}:nope:@{HOME}"`,
  1183. u, ""},
  1184. }
  1185. for i, tt := range tests {
  1186. got := pathFromPAMEnvLine([]byte(tt.line), tt.u)
  1187. if got != tt.want {
  1188. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1189. }
  1190. }
  1191. }
  1192. func TestExpandDefaultPathTmpl(t *testing.T) {
  1193. u := &user.User{Username: "foo", HomeDir: "/Homes/Foo"}
  1194. tests := []struct {
  1195. t string
  1196. u *user.User
  1197. want string
  1198. }{
  1199. {"", u, ""},
  1200. {`/run/wrappers/bin:@{HOME}/.nix-profile/bin:/etc/profiles/per-user/@{PAM_USER}/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin`,
  1201. u, "/run/wrappers/bin:/Homes/Foo/.nix-profile/bin:/etc/profiles/per-user/foo/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin"},
  1202. {`@{SOMETHING_ELSE}:nope:@{HOME}`, u, ""},
  1203. }
  1204. for i, tt := range tests {
  1205. got := expandDefaultPathTmpl(tt.t, tt.u)
  1206. if got != tt.want {
  1207. t.Errorf("%d. got %q; want %q", i, got, tt.want)
  1208. }
  1209. }
  1210. }
  1211. func TestPathFromPAMEnvLineOnNixOS(t *testing.T) {
  1212. if runtime.GOOS != "linux" {
  1213. t.Skip("skipping on non-linux")
  1214. }
  1215. if distro.Get() != distro.NixOS {
  1216. t.Skip("skipping on non-NixOS")
  1217. }
  1218. u, err := user.Current()
  1219. if err != nil {
  1220. t.Fatal(err)
  1221. }
  1222. got := defaultPathForUserOnNixOS(u)
  1223. if got == "" {
  1224. x, err := os.ReadFile("/etc/pam/environment")
  1225. t.Fatalf("no result. file was: err=%v, contents=%s", err, x)
  1226. }
  1227. t.Logf("success; got=%q", got)
  1228. }
  1229. func TestStdOsUserUserAssumptions(t *testing.T) {
  1230. v := reflect.TypeFor[user.User]()
  1231. if got, want := v.NumField(), 5; got != want {
  1232. t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
  1233. }
  1234. }