| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package tsweb
- import (
- "bytes"
- "compress/gzip"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
// Data is the sample payload exchanged in the JSON handler tests: handlers
// decode it from request bodies and hand it back as response data.
type Data struct {
	// Name is the only field the name-validating handlers inspect.
	Name string
	// Price is the numeric field the price-validating handler reads and
	// doubles in its reply.
	Price int
}
- type Response struct {
- Status string
- Error string
- Data *Data
- }
- func TestNewJSONHandler(t *testing.T) {
- checkStatus := func(t *testing.T, w *httptest.ResponseRecorder, status string, code int) *Response {
- d := &Response{
- Data: &Data{},
- }
- bodyBytes := w.Body.Bytes()
- if w.Result().Header.Get("Content-Encoding") == "gzip" {
- zr, err := gzip.NewReader(bytes.NewReader(bodyBytes))
- if err != nil {
- t.Fatalf("gzip read error at start: %v", err)
- }
- bodyBytes, err = io.ReadAll(zr)
- if err != nil {
- t.Fatalf("gzip read error: %v", err)
- }
- }
- t.Logf("%s", bodyBytes)
- err := json.Unmarshal(bodyBytes, d)
- if err != nil {
- t.Logf(err.Error())
- return nil
- }
- if d.Status == status {
- t.Logf("ok: %s", d.Status)
- } else {
- t.Fatalf("wrong status: got: %s, want: %s", d.Status, status)
- }
- if w.Code != code {
- t.Fatalf("wrong status code: got: %d, want: %d", w.Code, code)
- }
- if w.Header().Get("Content-Type") != "application/json" {
- t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
- }
- return d
- }
- h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- return http.StatusOK, nil, nil
- })
- t.Run("200 simple", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("GET", "/", nil)
- h21.ServeHTTPReturn(w, r)
- checkStatus(t, w, "success", http.StatusOK)
- })
- t.Run("403 HTTPError", func(t *testing.T) {
- h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- return 0, nil, Error(http.StatusForbidden, "forbidden", nil)
- })
- w := httptest.NewRecorder()
- r := httptest.NewRequest("GET", "/", nil)
- h.ServeHTTPReturn(w, r)
- checkStatus(t, w, "error", http.StatusForbidden)
- })
- h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- return http.StatusOK, &Data{Name: "tailscale"}, nil
- })
- t.Run("200 get data", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("GET", "/", nil)
- h22.ServeHTTPReturn(w, r)
- checkStatus(t, w, "success", http.StatusOK)
- })
- h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- body := new(Data)
- if err := json.NewDecoder(r.Body).Decode(body); err != nil {
- return 0, nil, Error(http.StatusBadRequest, err.Error(), err)
- }
- if body.Name == "" {
- return 0, nil, Error(http.StatusBadRequest, "name is empty", nil)
- }
- return http.StatusOK, nil, nil
- })
- t.Run("200 post data", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
- h31.ServeHTTPReturn(w, r)
- checkStatus(t, w, "success", http.StatusOK)
- })
- t.Run("400 bad json", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
- h31.ServeHTTPReturn(w, r)
- checkStatus(t, w, "error", http.StatusBadRequest)
- })
- t.Run("400 post data error", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
- h31.ServeHTTPReturn(w, r)
- resp := checkStatus(t, w, "error", http.StatusBadRequest)
- if resp.Error != "name is empty" {
- t.Fatalf("wrong error")
- }
- })
- h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- body := new(Data)
- if err := json.NewDecoder(r.Body).Decode(body); err != nil {
- return 0, nil, Error(http.StatusBadRequest, err.Error(), err)
- }
- if body.Name == "root" {
- return 0, nil, fmt.Errorf("invalid name")
- }
- if body.Price == 0 {
- return 0, nil, Error(http.StatusBadRequest, "price is empty", nil)
- }
- return http.StatusOK, &Data{Price: body.Price * 2}, nil
- })
- t.Run("200 post data", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
- h32.ServeHTTPReturn(w, r)
- resp := checkStatus(t, w, "success", http.StatusOK)
- t.Log(resp.Data)
- if resp.Data.Price != 20 {
- t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
- }
- })
- t.Run("gzipped", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
- r.Header.Set("Accept-Encoding", "gzip")
- h32.ServeHTTPReturn(w, r)
- res := w.Result()
- if ct := res.Header.Get("Content-Encoding"); ct != "gzip" {
- t.Fatalf("encoding = %q; want gzip", ct)
- }
- resp := checkStatus(t, w, "success", http.StatusOK)
- t.Log(resp.Data)
- if resp.Data.Price != 20 {
- t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
- }
- })
- t.Run("400 post data error", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
- h32.ServeHTTPReturn(w, r)
- resp := checkStatus(t, w, "error", http.StatusBadRequest)
- if resp.Error != "price is empty" {
- t.Fatalf("wrong error")
- }
- })
- t.Run("500 internal server error (unspecified error, not of type HTTPError)", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
- h32.ServeHTTPReturn(w, r)
- resp := checkStatus(t, w, "error", http.StatusInternalServerError)
- if resp.Error != "internal server error" {
- t.Fatalf("wrong error")
- }
- })
- t.Run("500 misuse", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", nil)
- JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- return http.StatusOK, make(chan int), nil
- }).ServeHTTPReturn(w, r)
- resp := checkStatus(t, w, "error", http.StatusInternalServerError)
- if resp.Error != "json marshal error" {
- t.Fatalf("wrong error")
- }
- })
- t.Run("500 empty status code", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", nil)
- JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
- return
- }).ServeHTTPReturn(w, r)
- checkStatus(t, w, "error", http.StatusInternalServerError)
- })
- t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", nil)
- JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- return http.StatusForbidden, nil, Error(http.StatusForbidden, "403 forbidden", nil)
- }).ServeHTTPReturn(w, r)
- want := &Response{
- Status: "error",
- Data: &Data{},
- Error: "403 forbidden",
- }
- got := checkStatus(t, w, "error", http.StatusForbidden)
- if diff := cmp.Diff(want, got); diff != "" {
- t.Fatalf(diff)
- }
- })
- t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError do not agree", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest("POST", "/", nil)
- err := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
- return http.StatusInternalServerError, nil, Error(http.StatusForbidden, "403 forbidden", nil)
- }).ServeHTTPReturn(w, r)
- if !strings.HasPrefix(err.Error(), "[unexpected]") {
- t.Fatalf("returned error should have `[unexpected]` to note the disagreeing status codes: %v", err)
- }
- want := &Response{
- Status: "error",
- Data: &Data{},
- Error: "403 forbidden",
- }
- got := checkStatus(t, w, "error", http.StatusForbidden)
- if diff := cmp.Diff(want, got); diff != "" {
- t.Fatalf("(-want,+got):\n%s", diff)
- }
- })
- }
- func TestAcceptsEncoding(t *testing.T) {
- tests := []struct {
- in, enc string
- want bool
- }{
- {"", "gzip", false},
- {"gzip", "gzip", true},
- {"foo,gzip", "gzip", true},
- {"foo, gzip", "gzip", true},
- {"foo, gzip ", "gzip", true},
- {"gzip, foo ", "gzip", true},
- {"gzip, foo ", "br", false},
- {"gzip, foo ", "fo", false},
- {"gzip;q=1.2, foo ", "gzip", true},
- {" gzip;q=1.2, foo ", "gzip", true},
- }
- for i, tt := range tests {
- h := make(http.Header)
- if tt.in != "" {
- h.Set("Accept-Encoding", tt.in)
- }
- got := AcceptsEncoding(&http.Request{Header: h}, tt.enc)
- if got != tt.want {
- t.Errorf("%d. got %v; want %v", i, got, tt.want)
- }
- }
- }
|