| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649 |
- // Copyright (c) Tailscale Inc & contributors
- // SPDX-License-Identifier: BSD-3-Clause
- package safeweb
- import (
- "io"
- "net/http"
- "net/http/httptest"
- "strconv"
- "strings"
- "testing"
- "time"
- "github.com/google/go-cmp/cmp"
- "github.com/gorilla/csrf"
- )
- func TestCompleteCORSConfig(t *testing.T) {
- _, err := NewServer(Config{AccessControlAllowOrigin: []string{"https://foobar.com"}})
- if err == nil {
- t.Fatalf("expected error when AccessControlAllowOrigin is provided without AccessControlAllowMethods")
- }
- _, err = NewServer(Config{AccessControlAllowMethods: []string{"GET", "POST"}})
- if err == nil {
- t.Fatalf("expected error when AccessControlAllowMethods is provided without AccessControlAllowOrigin")
- }
- _, err = NewServer(Config{AccessControlAllowOrigin: []string{"https://foobar.com"}, AccessControlAllowMethods: []string{"GET", "POST"}})
- if err != nil {
- t.Fatalf("error creating server with complete CORS configuration: %v", err)
- }
- }
- func TestPostRequestContentTypeValidation(t *testing.T) {
- tests := []struct {
- name string
- browserRoute bool
- contentType string
- wantErr bool
- }{
- {
- name: "API routes should accept `application/json` content-type",
- browserRoute: false,
- contentType: "application/json",
- wantErr: false,
- },
- {
- name: "API routes should reject `application/x-www-form-urlencoded` content-type",
- browserRoute: false,
- contentType: "application/x-www-form-urlencoded",
- wantErr: true,
- },
- {
- name: "Browser routes should accept `application/x-www-form-urlencoded` content-type",
- browserRoute: true,
- contentType: "application/x-www-form-urlencoded",
- wantErr: false,
- },
- {
- name: "non Browser routes should accept `application/json` content-type",
- browserRoute: true,
- contentType: "application/json",
- wantErr: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- var s *Server
- var err error
- if tt.browserRoute {
- s, err = NewServer(Config{BrowserMux: h})
- } else {
- s, err = NewServer(Config{APIMux: h})
- }
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("POST", "/", nil)
- req.Header.Set("Content-Type", tt.contentType)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- if tt.wantErr && resp.StatusCode != http.StatusBadRequest {
- t.Fatalf("content type validation failed: got %v; want %v", resp.StatusCode, http.StatusBadRequest)
- }
- })
- }
- }
- func TestAPIMuxCrossOriginResourceSharingHeaders(t *testing.T) {
- tests := []struct {
- name string
- httpMethod string
- wantCORSHeaders bool
- corsOrigins []string
- corsMethods []string
- }{
- {
- name: "do not set CORS headers for non-OPTIONS requests",
- corsOrigins: []string{"https://foobar.com"},
- corsMethods: []string{"GET", "POST", "HEAD"},
- httpMethod: "GET",
- wantCORSHeaders: false,
- },
- {
- name: "set CORS headers for non-OPTIONS requests",
- corsOrigins: []string{"https://foobar.com"},
- corsMethods: []string{"GET", "POST", "HEAD"},
- httpMethod: "OPTIONS",
- wantCORSHeaders: true,
- },
- {
- name: "do not serve CORS headers for OPTIONS requests with no configured origins",
- httpMethod: "OPTIONS",
- wantCORSHeaders: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- s, err := NewServer(Config{
- APIMux: h,
- AccessControlAllowOrigin: tt.corsOrigins,
- AccessControlAllowMethods: tt.corsMethods,
- })
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest(tt.httpMethod, "/", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- if (resp.Header.Get("Access-Control-Allow-Origin") == "") == tt.wantCORSHeaders {
- t.Fatalf("access-control-allow-origin want: %v; got: %v", tt.wantCORSHeaders, resp.Header.Get("Access-Control-Allow-Origin"))
- }
- })
- }
- }
- func TestCSRFProtection(t *testing.T) {
- tests := []struct {
- name string
- apiRoute bool
- passCSRFToken bool
- wantStatus int
- }{
- {
- name: "POST requests to non-API routes require CSRF token and fail if not provided",
- apiRoute: false,
- passCSRFToken: false,
- wantStatus: http.StatusForbidden,
- },
- {
- name: "POST requests to non-API routes require CSRF token and pass if provided",
- apiRoute: false,
- passCSRFToken: true,
- wantStatus: http.StatusOK,
- },
- {
- name: "POST requests to /api/ routes do not require CSRF token",
- apiRoute: true,
- passCSRFToken: false,
- wantStatus: http.StatusOK,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- var s *Server
- var err error
- if tt.apiRoute {
- s, err = NewServer(Config{APIMux: h})
- } else {
- s, err = NewServer(Config{BrowserMux: h})
- }
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- // construct the test request
- req := httptest.NewRequest("POST", "/", nil)
- // send JSON for API routes, form data for browser routes
- if tt.apiRoute {
- req.Header.Set("Content-Type", "application/json")
- } else {
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- }
- // retrieve CSRF cookie & pass it in the test request
- // ref: https://github.com/gorilla/csrf/blob/main/csrf_test.go#L344-L347
- var token string
- if tt.passCSRFToken {
- h.Handle("/csrf", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
- token = csrf.Token(r)
- }))
- get := httptest.NewRequest("GET", "/csrf", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, get)
- resp := w.Result()
- // pass the token & cookie in our subsequent test request
- req.Header.Set("X-CSRF-Token", token)
- for _, c := range resp.Cookies() {
- req.AddCookie(c)
- }
- }
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- if resp.StatusCode != tt.wantStatus {
- t.Fatalf("csrf protection check failed: got %v; want %v", resp.StatusCode, tt.wantStatus)
- }
- })
- }
- }
- func TestContentSecurityPolicyHeader(t *testing.T) {
- tests := []struct {
- name string
- csp CSP
- apiRoute bool
- wantCSP string
- }{
- {
- name: "default CSP",
- wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`,
- },
- {
- name: "custom CSP",
- csp: CSP{
- "default-src": {"'self'", "https://tailscale.com"},
- "upgrade-insecure-requests": nil,
- },
- wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`,
- },
- {
- name: "`/api/*` routes do not get CSP headers",
- apiRoute: true,
- wantCSP: "",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- var s *Server
- var err error
- if tt.apiRoute {
- s, err = NewServer(Config{APIMux: h, CSP: tt.csp})
- } else {
- s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp})
- }
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("GET", "/", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP {
- t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got)
- }
- })
- }
- }
- func TestCSRFCookieSecureMode(t *testing.T) {
- tests := []struct {
- name string
- secureMode bool
- wantSecure bool
- }{
- {
- name: "CSRF cookie should be secure when server is in secure context",
- secureMode: true,
- wantSecure: true,
- },
- {
- name: "CSRF cookie should not be secure when server is not in secure context",
- secureMode: false,
- wantSecure: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- s, err := NewServer(Config{BrowserMux: h, SecureContext: tt.secureMode})
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("GET", "/", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- cookie := resp.Cookies()[0]
- if (cookie.Secure == tt.wantSecure) == false {
- t.Fatalf("csrf cookie secure flag want: %v; got: %v", tt.wantSecure, cookie.Secure)
- }
- })
- }
- }
- func TestRefererPolicy(t *testing.T) {
- tests := []struct {
- name string
- browserRoute bool
- wantRefererPolicy bool
- }{
- {
- name: "BrowserMux routes get Referer-Policy headers",
- browserRoute: true,
- wantRefererPolicy: true,
- },
- {
- name: "APIMux routes do not get Referer-Policy headers",
- browserRoute: false,
- wantRefererPolicy: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- var s *Server
- var err error
- if tt.browserRoute {
- s, err = NewServer(Config{BrowserMux: h})
- } else {
- s, err = NewServer(Config{APIMux: h})
- }
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("GET", "/", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- if (resp.Header.Get("Referer-Policy") == "") == tt.wantRefererPolicy {
- t.Fatalf("referer policy want: %v; got: %v", tt.wantRefererPolicy, resp.Header.Get("Referer-Policy"))
- }
- })
- }
- }
- func TestCSPAllowInlineStyles(t *testing.T) {
- for _, allow := range []bool{false, true} {
- t.Run(strconv.FormatBool(allow), func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- s, err := NewServer(Config{BrowserMux: h, CSPAllowInlineStyles: allow})
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("GET", "/", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- csp := resp.Header.Get("Content-Security-Policy")
- allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'")
- if allowsStyles != allow {
- t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp)
- }
- })
- }
- }
- func TestRouting(t *testing.T) {
- for _, tt := range []struct {
- desc string
- browserPatterns []string
- apiPatterns []string
- requestPath string
- want string
- }{
- {
- desc: "only browser mux",
- browserPatterns: []string{"/"},
- requestPath: "/index.html",
- want: "browser",
- },
- {
- desc: "only API mux",
- apiPatterns: []string{"/api/"},
- requestPath: "/api/foo",
- want: "api",
- },
- {
- desc: "browser mux match",
- browserPatterns: []string{"/content/"},
- apiPatterns: []string{"/api/"},
- requestPath: "/content/index.html",
- want: "browser",
- },
- {
- desc: "API mux match",
- browserPatterns: []string{"/content/"},
- apiPatterns: []string{"/api/"},
- requestPath: "/api/foo",
- want: "api",
- },
- {
- desc: "browser wildcard match",
- browserPatterns: []string{"/"},
- apiPatterns: []string{"/api/"},
- requestPath: "/index.html",
- want: "browser",
- },
- {
- desc: "API wildcard match",
- browserPatterns: []string{"/content/"},
- apiPatterns: []string{"/"},
- requestPath: "/api/foo",
- want: "api",
- },
- {
- desc: "path conflict",
- browserPatterns: []string{"/foo/"},
- apiPatterns: []string{"/foo/bar/"},
- requestPath: "/foo/bar/baz",
- want: "api",
- },
- {
- desc: "no match",
- browserPatterns: []string{"/foo/"},
- apiPatterns: []string{"/bar/"},
- requestPath: "/baz",
- want: "404 page not found",
- },
- } {
- t.Run(tt.desc, func(t *testing.T) {
- bm := &http.ServeMux{}
- for _, p := range tt.browserPatterns {
- bm.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("browser"))
- })
- }
- am := &http.ServeMux{}
- for _, p := range tt.apiPatterns {
- am.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("api"))
- })
- }
- s, err := NewServer(Config{BrowserMux: bm, APIMux: am})
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("GET", tt.requestPath, nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp, err := io.ReadAll(w.Result().Body)
- if err != nil {
- t.Fatal(err)
- }
- if got := strings.TrimSpace(string(resp)); got != tt.want {
- t.Errorf("got response %q, want %q", got, tt.want)
- }
- })
- }
- }
- func TestGetMoreSpecificPattern(t *testing.T) {
- for _, tt := range []struct {
- desc string
- a string
- b string
- want handlerType
- }{
- {
- desc: "identical",
- a: "/foo/bar",
- b: "/foo/bar",
- want: unknownHandler,
- },
- {
- desc: "identical prefix",
- a: "/foo/bar/",
- b: "/foo/bar/",
- want: unknownHandler,
- },
- {
- desc: "trailing slash",
- a: "/foo",
- b: "/foo/", // path.Clean will strip the trailing slash.
- want: unknownHandler,
- },
- {
- desc: "same prefix",
- a: "/foo/bar/quux",
- b: "/foo/bar/", // path.Clean will strip the trailing slash.
- want: apiHandler,
- },
- {
- desc: "almost same prefix, but not a path component",
- a: "/goat/sheep/cheese",
- b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash.
- want: apiHandler,
- },
- {
- desc: "attempt to make less-specific pattern look more specific",
- a: "/goat/cat/buddy",
- b: "/goat/../../../../../../../cat", // path.Clean catches this foolishness
- want: apiHandler,
- },
- {
- desc: "2 names for / (1)",
- a: "/",
- b: "/../../../../../../",
- want: unknownHandler,
- },
- {
- desc: "2 names for / (2)",
- a: "/",
- b: "///////",
- want: unknownHandler,
- },
- {
- desc: "root-level",
- a: "/latest",
- b: "/", // path.Clean will NOT strip the trailing slash.
- want: apiHandler,
- },
- } {
- t.Run(tt.desc, func(t *testing.T) {
- got := checkHandlerType(tt.a, tt.b)
- if got != tt.want {
- t.Errorf("got %q, want %q", got, tt.want)
- }
- })
- }
- }
- func TestStrictTransportSecurityOptions(t *testing.T) {
- tests := []struct {
- name string
- options string
- secureContext bool
- expect string
- }{
- {
- name: "off by default",
- },
- {
- name: "default HSTS options in the secure context",
- secureContext: true,
- expect: DefaultStrictTransportSecurityOptions,
- },
- {
- name: "custom options sent in the secure context",
- options: DefaultStrictTransportSecurityOptions + "; includeSubDomains",
- secureContext: true,
- expect: DefaultStrictTransportSecurityOptions + "; includeSubDomains",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &http.ServeMux{}
- h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("ok"))
- }))
- s, err := NewServer(Config{BrowserMux: h, SecureContext: tt.secureContext, StrictTransportSecurityOptions: tt.options})
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- req := httptest.NewRequest("GET", "/", nil)
- w := httptest.NewRecorder()
- s.h.Handler.ServeHTTP(w, req)
- resp := w.Result()
- if cmp.Diff(tt.expect, resp.Header.Get("Strict-Transport-Security")) != "" {
- t.Fatalf("HSTS want: %q; got: %q", tt.expect, resp.Header.Get("Strict-Transport-Security"))
- }
- })
- }
- }
- func TestOverrideHTTPServer(t *testing.T) {
- s, err := NewServer(Config{})
- if err != nil {
- t.Fatalf("NewServer: %v", err)
- }
- if s.h.IdleTimeout != 0 {
- t.Fatalf("got %v; want 0", s.h.IdleTimeout)
- }
- c := http.Server{
- IdleTimeout: 10 * time.Second,
- }
- s, err = NewServer(Config{HTTPServer: &c})
- if err != nil {
- t.Fatalf("NewServer: %v", err)
- }
- if s.h.IdleTimeout != c.IdleTimeout {
- t.Fatalf("got %v; want %v", s.h.IdleTimeout, c.IdleTimeout)
- }
- }
|