| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- package tsweb
- import (
- "bufio"
- "context"
- "errors"
- "fmt"
- "net"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "github.com/google/go-cmp/cmp"
- "tailscale.com/tstest"
- "tailscale.com/util/vizerror"
- )
// noopHijacker is an httptest.ResponseRecorder that additionally implements
// the Hijack method of http.Hijacker. Hijacking only flips a flag; no
// connection is ever handed out, so it suits handlers that merely attempt a
// hijack and never touch the returned conn.
type noopHijacker struct {
	*httptest.ResponseRecorder
	hijacked bool // becomes true once Hijack has been called
}

// Hijack records that a hijack happened and reports success, returning a nil
// net.Conn and a nil *bufio.ReadWriter.
func (nh *noopHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	nh.hijacked = true
	var (
		conn net.Conn
		brw  *bufio.ReadWriter
	)
	return conn, brw, nil
}
// handlerFunc lets a plain function value serve as a ReturnHandler, which is
// how the test tables below build handlers from inline closures.
type handlerFunc func(http.ResponseWriter, *http.Request) error

// ServeHTTPReturn invokes the underlying function with w and r and passes its
// error straight back to the caller.
func (hf handlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
	result := hf(w, r)
	return result
}
- func TestStdHandler(t *testing.T) {
- const exampleRequestID = "example-request-id"
- var (
- handlerCode = func(code int) ReturnHandler {
- return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
- w.WriteHeader(code)
- return nil
- })
- }
- handlerErr = func(code int, err error) ReturnHandler {
- return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
- if code != 0 {
- w.WriteHeader(code)
- }
- return err
- })
- }
- req = func(ctx context.Context, url string) *http.Request {
- ret, err := http.NewRequestWithContext(ctx, "GET", url, nil)
- if err != nil {
- panic(err)
- }
- return ret
- }
- testErr = errors.New("test error")
- bgCtx = context.Background()
- // canceledCtx, cancel = context.WithCancel(bgCtx)
- startTime = time.Unix(1687870000, 1234)
- )
- // cancel()
- tests := []struct {
- name string
- rh ReturnHandler
- r *http.Request
- errHandler ErrorHandlerFunc
- wantCode int
- wantLog AccessLogRecord
- wantBody string
- }{
- {
- name: "handler returns 200",
- rh: handlerCode(200),
- r: req(bgCtx, "http://example.com/"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- TLS: false,
- Host: "example.com",
- Method: "GET",
- Code: 200,
- RequestURI: "/",
- },
- },
- {
- name: "handler returns 200 with request ID",
- rh: handlerCode(200),
- r: req(bgCtx, "http://example.com/"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- TLS: false,
- Host: "example.com",
- Method: "GET",
- Code: 200,
- RequestURI: "/",
- },
- },
- {
- name: "handler returns 404",
- rh: handlerCode(404),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 404,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Code: 404,
- },
- },
- {
- name: "handler returns 404 with request ID",
- rh: handlerCode(404),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 404,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Code: 404,
- },
- },
- {
- name: "handler returns 404 via HTTPError",
- rh: handlerErr(0, Error(404, "not found", testErr)),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 404,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "not found: " + testErr.Error(),
- Code: 404,
- },
- wantBody: "not found\n",
- },
- {
- name: "handler returns 404 via HTTPError with request ID",
- rh: handlerErr(0, Error(404, "not found", testErr)),
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
- wantCode: 404,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "not found: " + testErr.Error(),
- Code: 404,
- RequestID: exampleRequestID,
- },
- wantBody: "not found\n" + exampleRequestID + "\n",
- },
- {
- name: "handler returns 404 with nil child error",
- rh: handlerErr(0, Error(404, "not found", nil)),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 404,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "not found",
- Code: 404,
- },
- wantBody: "not found\n",
- },
- {
- name: "handler returns 404 with request ID and nil child error",
- rh: handlerErr(0, Error(404, "not found", nil)),
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
- wantCode: 404,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "not found",
- Code: 404,
- RequestID: exampleRequestID,
- },
- wantBody: "not found\n" + exampleRequestID + "\n",
- },
- {
- name: "handler returns user-visible error",
- rh: handlerErr(0, vizerror.New("visible error")),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 500,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "visible error",
- Code: 500,
- },
- wantBody: "visible error\n",
- },
- {
- name: "handler returns user-visible error with request ID",
- rh: handlerErr(0, vizerror.New("visible error")),
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
- wantCode: 500,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "visible error",
- Code: 500,
- RequestID: exampleRequestID,
- },
- wantBody: "visible error\n" + exampleRequestID + "\n",
- },
- {
- name: "handler returns user-visible error wrapped by private error",
- rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 500,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "visible error",
- Code: 500,
- },
- wantBody: "visible error\n",
- },
- {
- name: "handler returns user-visible error wrapped by private error with request ID",
- rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
- wantCode: 500,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "visible error",
- Code: 500,
- RequestID: exampleRequestID,
- },
- wantBody: "visible error\n" + exampleRequestID + "\n",
- },
- {
- name: "handler returns generic error",
- rh: handlerErr(0, testErr),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 500,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: testErr.Error(),
- Code: 500,
- },
- wantBody: "internal server error\n",
- },
- {
- name: "handler returns generic error with request ID",
- rh: handlerErr(0, testErr),
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
- wantCode: 500,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: testErr.Error(),
- Code: 500,
- RequestID: exampleRequestID,
- },
- wantBody: "internal server error\n" + exampleRequestID + "\n",
- },
- {
- name: "handler returns error after writing response",
- rh: handlerErr(200, testErr),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: testErr.Error(),
- Code: 200,
- },
- },
- {
- name: "handler returns error after writing response with request ID",
- rh: handlerErr(200, testErr),
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: testErr.Error(),
- Code: 200,
- RequestID: exampleRequestID,
- },
- },
- {
- name: "handler returns HTTPError after writing response",
- rh: handlerErr(200, Error(404, "not found", testErr)),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Err: "not found: " + testErr.Error(),
- Code: 200,
- },
- },
- {
- name: "handler does nothing",
- rh: handlerFunc(func(http.ResponseWriter, *http.Request) error { return nil }),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Code: 200,
- },
- },
- {
- name: "handler hijacks conn",
- rh: handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
- _, _, err := w.(http.Hijacker).Hijack()
- if err != nil {
- t.Errorf("couldn't hijack: %v", err)
- }
- return err
- }),
- r: req(bgCtx, "http://example.com/foo"),
- wantCode: 200,
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- Host: "example.com",
- Method: "GET",
- RequestURI: "/foo",
- Code: 101,
- },
- },
- {
- name: "error handler gets run",
- rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
- r: req(bgCtx, "http://example.com/"),
- wantCode: 200,
- errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
- http.Error(w, e.Msg, 200)
- },
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- TLS: false,
- Host: "example.com",
- Method: "GET",
- Code: 404,
- Err: "not found",
- RequestURI: "/",
- },
- wantBody: "not found\n",
- },
- {
- name: "error handler gets run with request ID",
- rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
- r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/"),
- wantCode: 200,
- errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
- requestID := RequestIDFromContext(r.Context())
- http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, requestID), 200)
- },
- wantLog: AccessLogRecord{
- When: startTime,
- Seconds: 1.0,
- Proto: "HTTP/1.1",
- TLS: false,
- Host: "example.com",
- Method: "GET",
- Code: 404,
- Err: "not found",
- RequestURI: "/",
- RequestID: exampleRequestID,
- },
- wantBody: "not found with request ID " + exampleRequestID + "\n",
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var logs []AccessLogRecord
- logf := func(fmt string, args ...any) {
- if fmt == "%s" {
- logs = append(logs, args[0].(AccessLogRecord))
- }
- t.Logf(fmt, args...)
- }
- clock := tstest.NewClock(tstest.ClockOpts{
- Start: startTime,
- Step: time.Second,
- })
- rec := noopHijacker{httptest.NewRecorder(), false}
- h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
- h.ServeHTTP(&rec, test.r)
- res := rec.Result()
- if res.StatusCode != test.wantCode {
- t.Errorf("HTTP code = %v, want %v", res.StatusCode, test.wantCode)
- }
- if len(logs) != 1 {
- t.Errorf("handler didn't write a request log")
- return
- }
- errTransform := cmp.Transformer("err", func(e error) string {
- if e == nil {
- return ""
- }
- return e.Error()
- })
- if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
- t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
- }
- if diff := cmp.Diff(rec.Body.String(), test.wantBody); diff != "" {
- t.Errorf("handler wrote incorrect body (-got+want):\n%s", diff)
- }
- })
- }
- }
- func BenchmarkLogNot200(b *testing.B) {
- b.ReportAllocs()
- rh := handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
- // Implicit 200 OK.
- return nil
- })
- h := StdHandler(rh, HandlerOptions{QuietLoggingIfSuccessful: true})
- req := httptest.NewRequest("GET", "/", nil)
- rw := new(httptest.ResponseRecorder)
- for i := 0; i < b.N; i++ {
- *rw = httptest.ResponseRecorder{}
- h.ServeHTTP(rw, req)
- }
- }
- func BenchmarkLog(b *testing.B) {
- b.ReportAllocs()
- rh := handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
- // Implicit 200 OK.
- return nil
- })
- h := StdHandler(rh, HandlerOptions{})
- req := httptest.NewRequest("GET", "/", nil)
- rw := new(httptest.ResponseRecorder)
- for i := 0; i < b.N; i++ {
- *rw = httptest.ResponseRecorder{}
- h.ServeHTTP(rw, req)
- }
- }
- func TestHTTPError_Unwrap(t *testing.T) {
- wrappedErr := fmt.Errorf("wrapped")
- err := Error(404, "not found", wrappedErr)
- if got := errors.Unwrap(err); got != wrappedErr {
- t.Errorf("HTTPError.Unwrap() = %v, want %v", got, wrappedErr)
- }
- }
- func TestAcceptsEncoding(t *testing.T) {
- tests := []struct {
- in, enc string
- want bool
- }{
- {"", "gzip", false},
- {"gzip", "gzip", true},
- {"foo,gzip", "gzip", true},
- {"foo, gzip", "gzip", true},
- {"foo, gzip ", "gzip", true},
- {"gzip, foo ", "gzip", true},
- {"gzip, foo ", "br", false},
- {"gzip, foo ", "fo", false},
- {"gzip;q=1.2, foo ", "gzip", true},
- {" gzip;q=1.2, foo ", "gzip", true},
- }
- for i, tt := range tests {
- h := make(http.Header)
- if tt.in != "" {
- h.Set("Accept-Encoding", tt.in)
- }
- got := AcceptsEncoding(&http.Request{Header: h}, tt.enc)
- if got != tt.want {
- t.Errorf("%d. got %v; want %v", i, got, tt.want)
- }
- }
- }
- func TestPort80Handler(t *testing.T) {
- tests := []struct {
- name string
- h *Port80Handler
- req string
- wantLoc string
- }{
- {
- name: "no_fqdn",
- h: &Port80Handler{},
- req: "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n",
- wantLoc: "https://foo.com/",
- },
- {
- name: "fqdn_and_path",
- h: &Port80Handler{FQDN: "bar.com"},
- req: "GET /path HTTP/1.1\r\nHost: foo.com\r\n\r\n",
- wantLoc: "https://bar.com/path",
- },
- {
- name: "path_and_query_string",
- h: &Port80Handler{FQDN: "baz.com"},
- req: "GET /path?a=b HTTP/1.1\r\nHost: foo.com\r\n\r\n",
- wantLoc: "https://baz.com/path?a=b",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- r, _ := http.ReadRequest(bufio.NewReader(strings.NewReader(tt.req)))
- rec := httptest.NewRecorder()
- tt.h.ServeHTTP(rec, r)
- got := rec.Result()
- if got, want := got.StatusCode, 302; got != want {
- t.Errorf("got status code %v; want %v", got, want)
- }
- if got, want := got.Header.Get("Location"), "https://foo.com/"; got != tt.wantLoc {
- t.Errorf("Location = %q; want %q", got, want)
- }
- })
- }
- }
|