http_test.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. // Copyright (c) Tailscale Inc & contributors
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package safeweb
  4. import (
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "strconv"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/gorilla/csrf"
  14. )
  15. func TestCompleteCORSConfig(t *testing.T) {
  16. _, err := NewServer(Config{AccessControlAllowOrigin: []string{"https://foobar.com"}})
  17. if err == nil {
  18. t.Fatalf("expected error when AccessControlAllowOrigin is provided without AccessControlAllowMethods")
  19. }
  20. _, err = NewServer(Config{AccessControlAllowMethods: []string{"GET", "POST"}})
  21. if err == nil {
  22. t.Fatalf("expected error when AccessControlAllowMethods is provided without AccessControlAllowOrigin")
  23. }
  24. _, err = NewServer(Config{AccessControlAllowOrigin: []string{"https://foobar.com"}, AccessControlAllowMethods: []string{"GET", "POST"}})
  25. if err != nil {
  26. t.Fatalf("error creating server with complete CORS configuration: %v", err)
  27. }
  28. }
  29. func TestPostRequestContentTypeValidation(t *testing.T) {
  30. tests := []struct {
  31. name string
  32. browserRoute bool
  33. contentType string
  34. wantErr bool
  35. }{
  36. {
  37. name: "API routes should accept `application/json` content-type",
  38. browserRoute: false,
  39. contentType: "application/json",
  40. wantErr: false,
  41. },
  42. {
  43. name: "API routes should reject `application/x-www-form-urlencoded` content-type",
  44. browserRoute: false,
  45. contentType: "application/x-www-form-urlencoded",
  46. wantErr: true,
  47. },
  48. {
  49. name: "Browser routes should accept `application/x-www-form-urlencoded` content-type",
  50. browserRoute: true,
  51. contentType: "application/x-www-form-urlencoded",
  52. wantErr: false,
  53. },
  54. {
  55. name: "non Browser routes should accept `application/json` content-type",
  56. browserRoute: true,
  57. contentType: "application/json",
  58. wantErr: false,
  59. },
  60. }
  61. for _, tt := range tests {
  62. t.Run(tt.name, func(t *testing.T) {
  63. h := &http.ServeMux{}
  64. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  65. w.Write([]byte("ok"))
  66. }))
  67. var s *Server
  68. var err error
  69. if tt.browserRoute {
  70. s, err = NewServer(Config{BrowserMux: h})
  71. } else {
  72. s, err = NewServer(Config{APIMux: h})
  73. }
  74. if err != nil {
  75. t.Fatal(err)
  76. }
  77. defer s.Close()
  78. req := httptest.NewRequest("POST", "/", nil)
  79. req.Header.Set("Content-Type", tt.contentType)
  80. w := httptest.NewRecorder()
  81. s.h.Handler.ServeHTTP(w, req)
  82. resp := w.Result()
  83. if tt.wantErr && resp.StatusCode != http.StatusBadRequest {
  84. t.Fatalf("content type validation failed: got %v; want %v", resp.StatusCode, http.StatusBadRequest)
  85. }
  86. })
  87. }
  88. }
  89. func TestAPIMuxCrossOriginResourceSharingHeaders(t *testing.T) {
  90. tests := []struct {
  91. name string
  92. httpMethod string
  93. wantCORSHeaders bool
  94. corsOrigins []string
  95. corsMethods []string
  96. }{
  97. {
  98. name: "do not set CORS headers for non-OPTIONS requests",
  99. corsOrigins: []string{"https://foobar.com"},
  100. corsMethods: []string{"GET", "POST", "HEAD"},
  101. httpMethod: "GET",
  102. wantCORSHeaders: false,
  103. },
  104. {
  105. name: "set CORS headers for non-OPTIONS requests",
  106. corsOrigins: []string{"https://foobar.com"},
  107. corsMethods: []string{"GET", "POST", "HEAD"},
  108. httpMethod: "OPTIONS",
  109. wantCORSHeaders: true,
  110. },
  111. {
  112. name: "do not serve CORS headers for OPTIONS requests with no configured origins",
  113. httpMethod: "OPTIONS",
  114. wantCORSHeaders: false,
  115. },
  116. }
  117. for _, tt := range tests {
  118. t.Run(tt.name, func(t *testing.T) {
  119. h := &http.ServeMux{}
  120. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  121. w.Write([]byte("ok"))
  122. }))
  123. s, err := NewServer(Config{
  124. APIMux: h,
  125. AccessControlAllowOrigin: tt.corsOrigins,
  126. AccessControlAllowMethods: tt.corsMethods,
  127. })
  128. if err != nil {
  129. t.Fatal(err)
  130. }
  131. defer s.Close()
  132. req := httptest.NewRequest(tt.httpMethod, "/", nil)
  133. w := httptest.NewRecorder()
  134. s.h.Handler.ServeHTTP(w, req)
  135. resp := w.Result()
  136. if (resp.Header.Get("Access-Control-Allow-Origin") == "") == tt.wantCORSHeaders {
  137. t.Fatalf("access-control-allow-origin want: %v; got: %v", tt.wantCORSHeaders, resp.Header.Get("Access-Control-Allow-Origin"))
  138. }
  139. })
  140. }
  141. }
  142. func TestCSRFProtection(t *testing.T) {
  143. tests := []struct {
  144. name string
  145. apiRoute bool
  146. passCSRFToken bool
  147. wantStatus int
  148. }{
  149. {
  150. name: "POST requests to non-API routes require CSRF token and fail if not provided",
  151. apiRoute: false,
  152. passCSRFToken: false,
  153. wantStatus: http.StatusForbidden,
  154. },
  155. {
  156. name: "POST requests to non-API routes require CSRF token and pass if provided",
  157. apiRoute: false,
  158. passCSRFToken: true,
  159. wantStatus: http.StatusOK,
  160. },
  161. {
  162. name: "POST requests to /api/ routes do not require CSRF token",
  163. apiRoute: true,
  164. passCSRFToken: false,
  165. wantStatus: http.StatusOK,
  166. },
  167. }
  168. for _, tt := range tests {
  169. t.Run(tt.name, func(t *testing.T) {
  170. h := &http.ServeMux{}
  171. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  172. w.Write([]byte("ok"))
  173. }))
  174. var s *Server
  175. var err error
  176. if tt.apiRoute {
  177. s, err = NewServer(Config{APIMux: h})
  178. } else {
  179. s, err = NewServer(Config{BrowserMux: h})
  180. }
  181. if err != nil {
  182. t.Fatal(err)
  183. }
  184. defer s.Close()
  185. // construct the test request
  186. req := httptest.NewRequest("POST", "/", nil)
  187. // send JSON for API routes, form data for browser routes
  188. if tt.apiRoute {
  189. req.Header.Set("Content-Type", "application/json")
  190. } else {
  191. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  192. }
  193. // retrieve CSRF cookie & pass it in the test request
  194. // ref: https://github.com/gorilla/csrf/blob/main/csrf_test.go#L344-L347
  195. var token string
  196. if tt.passCSRFToken {
  197. h.Handle("/csrf", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
  198. token = csrf.Token(r)
  199. }))
  200. get := httptest.NewRequest("GET", "/csrf", nil)
  201. w := httptest.NewRecorder()
  202. s.h.Handler.ServeHTTP(w, get)
  203. resp := w.Result()
  204. // pass the token & cookie in our subsequent test request
  205. req.Header.Set("X-CSRF-Token", token)
  206. for _, c := range resp.Cookies() {
  207. req.AddCookie(c)
  208. }
  209. }
  210. w := httptest.NewRecorder()
  211. s.h.Handler.ServeHTTP(w, req)
  212. resp := w.Result()
  213. if resp.StatusCode != tt.wantStatus {
  214. t.Fatalf("csrf protection check failed: got %v; want %v", resp.StatusCode, tt.wantStatus)
  215. }
  216. })
  217. }
  218. }
  219. func TestContentSecurityPolicyHeader(t *testing.T) {
  220. tests := []struct {
  221. name string
  222. csp CSP
  223. apiRoute bool
  224. wantCSP string
  225. }{
  226. {
  227. name: "default CSP",
  228. wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`,
  229. },
  230. {
  231. name: "custom CSP",
  232. csp: CSP{
  233. "default-src": {"'self'", "https://tailscale.com"},
  234. "upgrade-insecure-requests": nil,
  235. },
  236. wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`,
  237. },
  238. {
  239. name: "`/api/*` routes do not get CSP headers",
  240. apiRoute: true,
  241. wantCSP: "",
  242. },
  243. }
  244. for _, tt := range tests {
  245. t.Run(tt.name, func(t *testing.T) {
  246. h := &http.ServeMux{}
  247. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  248. w.Write([]byte("ok"))
  249. }))
  250. var s *Server
  251. var err error
  252. if tt.apiRoute {
  253. s, err = NewServer(Config{APIMux: h, CSP: tt.csp})
  254. } else {
  255. s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp})
  256. }
  257. if err != nil {
  258. t.Fatal(err)
  259. }
  260. defer s.Close()
  261. req := httptest.NewRequest("GET", "/", nil)
  262. w := httptest.NewRecorder()
  263. s.h.Handler.ServeHTTP(w, req)
  264. resp := w.Result()
  265. if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP {
  266. t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got)
  267. }
  268. })
  269. }
  270. }
  271. func TestCSRFCookieSecureMode(t *testing.T) {
  272. tests := []struct {
  273. name string
  274. secureMode bool
  275. wantSecure bool
  276. }{
  277. {
  278. name: "CSRF cookie should be secure when server is in secure context",
  279. secureMode: true,
  280. wantSecure: true,
  281. },
  282. {
  283. name: "CSRF cookie should not be secure when server is not in secure context",
  284. secureMode: false,
  285. wantSecure: false,
  286. },
  287. }
  288. for _, tt := range tests {
  289. t.Run(tt.name, func(t *testing.T) {
  290. h := &http.ServeMux{}
  291. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  292. w.Write([]byte("ok"))
  293. }))
  294. s, err := NewServer(Config{BrowserMux: h, SecureContext: tt.secureMode})
  295. if err != nil {
  296. t.Fatal(err)
  297. }
  298. defer s.Close()
  299. req := httptest.NewRequest("GET", "/", nil)
  300. w := httptest.NewRecorder()
  301. s.h.Handler.ServeHTTP(w, req)
  302. resp := w.Result()
  303. cookie := resp.Cookies()[0]
  304. if (cookie.Secure == tt.wantSecure) == false {
  305. t.Fatalf("csrf cookie secure flag want: %v; got: %v", tt.wantSecure, cookie.Secure)
  306. }
  307. })
  308. }
  309. }
  310. func TestRefererPolicy(t *testing.T) {
  311. tests := []struct {
  312. name string
  313. browserRoute bool
  314. wantRefererPolicy bool
  315. }{
  316. {
  317. name: "BrowserMux routes get Referer-Policy headers",
  318. browserRoute: true,
  319. wantRefererPolicy: true,
  320. },
  321. {
  322. name: "APIMux routes do not get Referer-Policy headers",
  323. browserRoute: false,
  324. wantRefererPolicy: false,
  325. },
  326. }
  327. for _, tt := range tests {
  328. t.Run(tt.name, func(t *testing.T) {
  329. h := &http.ServeMux{}
  330. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  331. w.Write([]byte("ok"))
  332. }))
  333. var s *Server
  334. var err error
  335. if tt.browserRoute {
  336. s, err = NewServer(Config{BrowserMux: h})
  337. } else {
  338. s, err = NewServer(Config{APIMux: h})
  339. }
  340. if err != nil {
  341. t.Fatal(err)
  342. }
  343. defer s.Close()
  344. req := httptest.NewRequest("GET", "/", nil)
  345. w := httptest.NewRecorder()
  346. s.h.Handler.ServeHTTP(w, req)
  347. resp := w.Result()
  348. if (resp.Header.Get("Referer-Policy") == "") == tt.wantRefererPolicy {
  349. t.Fatalf("referer policy want: %v; got: %v", tt.wantRefererPolicy, resp.Header.Get("Referer-Policy"))
  350. }
  351. })
  352. }
  353. }
  354. func TestCSPAllowInlineStyles(t *testing.T) {
  355. for _, allow := range []bool{false, true} {
  356. t.Run(strconv.FormatBool(allow), func(t *testing.T) {
  357. h := &http.ServeMux{}
  358. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  359. w.Write([]byte("ok"))
  360. }))
  361. s, err := NewServer(Config{BrowserMux: h, CSPAllowInlineStyles: allow})
  362. if err != nil {
  363. t.Fatal(err)
  364. }
  365. defer s.Close()
  366. req := httptest.NewRequest("GET", "/", nil)
  367. w := httptest.NewRecorder()
  368. s.h.Handler.ServeHTTP(w, req)
  369. resp := w.Result()
  370. csp := resp.Header.Get("Content-Security-Policy")
  371. allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'")
  372. if allowsStyles != allow {
  373. t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp)
  374. }
  375. })
  376. }
  377. }
  378. func TestRouting(t *testing.T) {
  379. for _, tt := range []struct {
  380. desc string
  381. browserPatterns []string
  382. apiPatterns []string
  383. requestPath string
  384. want string
  385. }{
  386. {
  387. desc: "only browser mux",
  388. browserPatterns: []string{"/"},
  389. requestPath: "/index.html",
  390. want: "browser",
  391. },
  392. {
  393. desc: "only API mux",
  394. apiPatterns: []string{"/api/"},
  395. requestPath: "/api/foo",
  396. want: "api",
  397. },
  398. {
  399. desc: "browser mux match",
  400. browserPatterns: []string{"/content/"},
  401. apiPatterns: []string{"/api/"},
  402. requestPath: "/content/index.html",
  403. want: "browser",
  404. },
  405. {
  406. desc: "API mux match",
  407. browserPatterns: []string{"/content/"},
  408. apiPatterns: []string{"/api/"},
  409. requestPath: "/api/foo",
  410. want: "api",
  411. },
  412. {
  413. desc: "browser wildcard match",
  414. browserPatterns: []string{"/"},
  415. apiPatterns: []string{"/api/"},
  416. requestPath: "/index.html",
  417. want: "browser",
  418. },
  419. {
  420. desc: "API wildcard match",
  421. browserPatterns: []string{"/content/"},
  422. apiPatterns: []string{"/"},
  423. requestPath: "/api/foo",
  424. want: "api",
  425. },
  426. {
  427. desc: "path conflict",
  428. browserPatterns: []string{"/foo/"},
  429. apiPatterns: []string{"/foo/bar/"},
  430. requestPath: "/foo/bar/baz",
  431. want: "api",
  432. },
  433. {
  434. desc: "no match",
  435. browserPatterns: []string{"/foo/"},
  436. apiPatterns: []string{"/bar/"},
  437. requestPath: "/baz",
  438. want: "404 page not found",
  439. },
  440. } {
  441. t.Run(tt.desc, func(t *testing.T) {
  442. bm := &http.ServeMux{}
  443. for _, p := range tt.browserPatterns {
  444. bm.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) {
  445. w.Write([]byte("browser"))
  446. })
  447. }
  448. am := &http.ServeMux{}
  449. for _, p := range tt.apiPatterns {
  450. am.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) {
  451. w.Write([]byte("api"))
  452. })
  453. }
  454. s, err := NewServer(Config{BrowserMux: bm, APIMux: am})
  455. if err != nil {
  456. t.Fatal(err)
  457. }
  458. defer s.Close()
  459. req := httptest.NewRequest("GET", tt.requestPath, nil)
  460. w := httptest.NewRecorder()
  461. s.h.Handler.ServeHTTP(w, req)
  462. resp, err := io.ReadAll(w.Result().Body)
  463. if err != nil {
  464. t.Fatal(err)
  465. }
  466. if got := strings.TrimSpace(string(resp)); got != tt.want {
  467. t.Errorf("got response %q, want %q", got, tt.want)
  468. }
  469. })
  470. }
  471. }
  472. func TestGetMoreSpecificPattern(t *testing.T) {
  473. for _, tt := range []struct {
  474. desc string
  475. a string
  476. b string
  477. want handlerType
  478. }{
  479. {
  480. desc: "identical",
  481. a: "/foo/bar",
  482. b: "/foo/bar",
  483. want: unknownHandler,
  484. },
  485. {
  486. desc: "identical prefix",
  487. a: "/foo/bar/",
  488. b: "/foo/bar/",
  489. want: unknownHandler,
  490. },
  491. {
  492. desc: "trailing slash",
  493. a: "/foo",
  494. b: "/foo/", // path.Clean will strip the trailing slash.
  495. want: unknownHandler,
  496. },
  497. {
  498. desc: "same prefix",
  499. a: "/foo/bar/quux",
  500. b: "/foo/bar/", // path.Clean will strip the trailing slash.
  501. want: apiHandler,
  502. },
  503. {
  504. desc: "almost same prefix, but not a path component",
  505. a: "/goat/sheep/cheese",
  506. b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash.
  507. want: apiHandler,
  508. },
  509. {
  510. desc: "attempt to make less-specific pattern look more specific",
  511. a: "/goat/cat/buddy",
  512. b: "/goat/../../../../../../../cat", // path.Clean catches this foolishness
  513. want: apiHandler,
  514. },
  515. {
  516. desc: "2 names for / (1)",
  517. a: "/",
  518. b: "/../../../../../../",
  519. want: unknownHandler,
  520. },
  521. {
  522. desc: "2 names for / (2)",
  523. a: "/",
  524. b: "///////",
  525. want: unknownHandler,
  526. },
  527. {
  528. desc: "root-level",
  529. a: "/latest",
  530. b: "/", // path.Clean will NOT strip the trailing slash.
  531. want: apiHandler,
  532. },
  533. } {
  534. t.Run(tt.desc, func(t *testing.T) {
  535. got := checkHandlerType(tt.a, tt.b)
  536. if got != tt.want {
  537. t.Errorf("got %q, want %q", got, tt.want)
  538. }
  539. })
  540. }
  541. }
  542. func TestStrictTransportSecurityOptions(t *testing.T) {
  543. tests := []struct {
  544. name string
  545. options string
  546. secureContext bool
  547. expect string
  548. }{
  549. {
  550. name: "off by default",
  551. },
  552. {
  553. name: "default HSTS options in the secure context",
  554. secureContext: true,
  555. expect: DefaultStrictTransportSecurityOptions,
  556. },
  557. {
  558. name: "custom options sent in the secure context",
  559. options: DefaultStrictTransportSecurityOptions + "; includeSubDomains",
  560. secureContext: true,
  561. expect: DefaultStrictTransportSecurityOptions + "; includeSubDomains",
  562. },
  563. }
  564. for _, tt := range tests {
  565. t.Run(tt.name, func(t *testing.T) {
  566. h := &http.ServeMux{}
  567. h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  568. w.Write([]byte("ok"))
  569. }))
  570. s, err := NewServer(Config{BrowserMux: h, SecureContext: tt.secureContext, StrictTransportSecurityOptions: tt.options})
  571. if err != nil {
  572. t.Fatal(err)
  573. }
  574. defer s.Close()
  575. req := httptest.NewRequest("GET", "/", nil)
  576. w := httptest.NewRecorder()
  577. s.h.Handler.ServeHTTP(w, req)
  578. resp := w.Result()
  579. if cmp.Diff(tt.expect, resp.Header.Get("Strict-Transport-Security")) != "" {
  580. t.Fatalf("HSTS want: %q; got: %q", tt.expect, resp.Header.Get("Strict-Transport-Security"))
  581. }
  582. })
  583. }
  584. }
  585. func TestOverrideHTTPServer(t *testing.T) {
  586. s, err := NewServer(Config{})
  587. if err != nil {
  588. t.Fatalf("NewServer: %v", err)
  589. }
  590. if s.h.IdleTimeout != 0 {
  591. t.Fatalf("got %v; want 0", s.h.IdleTimeout)
  592. }
  593. c := http.Server{
  594. IdleTimeout: 10 * time.Second,
  595. }
  596. s, err = NewServer(Config{HTTPServer: &c})
  597. if err != nil {
  598. t.Fatalf("NewServer: %v", err)
  599. }
  600. if s.h.IdleTimeout != c.IdleTimeout {
  601. t.Fatalf("got %v; want %v", s.h.IdleTimeout, c.IdleTimeout)
  602. }
  603. }