// jsonhandler_test.go (6.6 KB)
  1. // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tsweb
  5. import (
  6. "encoding/json"
  7. "fmt"
  8. "net/http"
  9. "net/http/httptest"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. )
  14. type Data struct {
  15. Name string
  16. Price int
  17. }
  18. type Response struct {
  19. Status string
  20. Error string
  21. Data *Data
  22. }
  23. func TestNewJSONHandler(t *testing.T) {
  24. checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response {
  25. d := &Response{
  26. Data: &Data{},
  27. }
  28. t.Logf("%s", w.Body.Bytes())
  29. err := json.Unmarshal(w.Body.Bytes(), d)
  30. if err != nil {
  31. t.Logf(err.Error())
  32. return nil
  33. }
  34. if d.Status == status {
  35. t.Logf("ok: %s", d.Status)
  36. } else {
  37. t.Fatalf("wrong status: got: %s, want: %s", d.Status, status)
  38. }
  39. if w.Code != code {
  40. t.Fatalf("wrong status code: got: %d, want: %d", w.Code, code)
  41. }
  42. if w.Header().Get("Content-Type") != "application/json" {
  43. t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
  44. }
  45. return d
  46. }
  47. h21 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  48. return http.StatusOK, nil, nil
  49. })
  50. t.Run("200 simple", func(t *testing.T) {
  51. w := httptest.NewRecorder()
  52. r := httptest.NewRequest("GET", "/", nil)
  53. h21.ServeHTTPReturn(w, r)
  54. checkStatus(w, "success", http.StatusOK)
  55. })
  56. t.Run("403 HTTPError", func(t *testing.T) {
  57. h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  58. return 0, nil, Error(http.StatusForbidden, "forbidden", nil)
  59. })
  60. w := httptest.NewRecorder()
  61. r := httptest.NewRequest("GET", "/", nil)
  62. h.ServeHTTPReturn(w, r)
  63. checkStatus(w, "error", http.StatusForbidden)
  64. })
  65. h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  66. return http.StatusOK, &Data{Name: "tailscale"}, nil
  67. })
  68. t.Run("200 get data", func(t *testing.T) {
  69. w := httptest.NewRecorder()
  70. r := httptest.NewRequest("GET", "/", nil)
  71. h22.ServeHTTPReturn(w, r)
  72. checkStatus(w, "success", http.StatusOK)
  73. })
  74. h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  75. body := new(Data)
  76. if err := json.NewDecoder(r.Body).Decode(body); err != nil {
  77. return 0, nil, Error(http.StatusBadRequest, err.Error(), err)
  78. }
  79. if body.Name == "" {
  80. return 0, nil, Error(http.StatusBadRequest, "name is empty", nil)
  81. }
  82. return http.StatusOK, nil, nil
  83. })
  84. t.Run("200 post data", func(t *testing.T) {
  85. w := httptest.NewRecorder()
  86. r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
  87. h31.ServeHTTPReturn(w, r)
  88. checkStatus(w, "success", http.StatusOK)
  89. })
  90. t.Run("400 bad json", func(t *testing.T) {
  91. w := httptest.NewRecorder()
  92. r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
  93. h31.ServeHTTPReturn(w, r)
  94. checkStatus(w, "error", http.StatusBadRequest)
  95. })
  96. t.Run("400 post data error", func(t *testing.T) {
  97. w := httptest.NewRecorder()
  98. r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
  99. h31.ServeHTTPReturn(w, r)
  100. resp := checkStatus(w, "error", http.StatusBadRequest)
  101. if resp.Error != "name is empty" {
  102. t.Fatalf("wrong error")
  103. }
  104. })
  105. h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  106. body := new(Data)
  107. if err := json.NewDecoder(r.Body).Decode(body); err != nil {
  108. return 0, nil, Error(http.StatusBadRequest, err.Error(), err)
  109. }
  110. if body.Name == "root" {
  111. return 0, nil, fmt.Errorf("invalid name")
  112. }
  113. if body.Price == 0 {
  114. return 0, nil, Error(http.StatusBadRequest, "price is empty", nil)
  115. }
  116. return http.StatusOK, &Data{Price: body.Price * 2}, nil
  117. })
  118. t.Run("200 post data", func(t *testing.T) {
  119. w := httptest.NewRecorder()
  120. r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
  121. h32.ServeHTTPReturn(w, r)
  122. resp := checkStatus(w, "success", http.StatusOK)
  123. t.Log(resp.Data)
  124. if resp.Data.Price != 20 {
  125. t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
  126. }
  127. })
  128. t.Run("400 post data error", func(t *testing.T) {
  129. w := httptest.NewRecorder()
  130. r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
  131. h32.ServeHTTPReturn(w, r)
  132. resp := checkStatus(w, "error", http.StatusBadRequest)
  133. if resp.Error != "price is empty" {
  134. t.Fatalf("wrong error")
  135. }
  136. })
  137. t.Run("500 internal server error (unspecified error, not of type HTTPError)", func(t *testing.T) {
  138. w := httptest.NewRecorder()
  139. r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
  140. h32.ServeHTTPReturn(w, r)
  141. resp := checkStatus(w, "error", http.StatusInternalServerError)
  142. if resp.Error != "internal server error" {
  143. t.Fatalf("wrong error")
  144. }
  145. })
  146. t.Run("500 misuse", func(t *testing.T) {
  147. w := httptest.NewRecorder()
  148. r := httptest.NewRequest("POST", "/", nil)
  149. JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  150. return http.StatusOK, make(chan int), nil
  151. }).ServeHTTPReturn(w, r)
  152. resp := checkStatus(w, "error", http.StatusInternalServerError)
  153. if resp.Error != "json marshal error" {
  154. t.Fatalf("wrong error")
  155. }
  156. })
  157. t.Run("500 empty status code", func(t *testing.T) {
  158. w := httptest.NewRecorder()
  159. r := httptest.NewRequest("POST", "/", nil)
  160. JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
  161. return
  162. }).ServeHTTPReturn(w, r)
  163. checkStatus(w, "error", http.StatusInternalServerError)
  164. })
  165. t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) {
  166. w := httptest.NewRecorder()
  167. r := httptest.NewRequest("POST", "/", nil)
  168. JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  169. return http.StatusForbidden, nil, Error(http.StatusForbidden, "403 forbidden", nil)
  170. }).ServeHTTPReturn(w, r)
  171. want := &Response{
  172. Status: "error",
  173. Data: &Data{},
  174. Error: "403 forbidden",
  175. }
  176. got := checkStatus(w, "error", http.StatusForbidden)
  177. if diff := cmp.Diff(want, got); diff != "" {
  178. t.Fatalf(diff)
  179. }
  180. })
  181. t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError do not agree", func(t *testing.T) {
  182. w := httptest.NewRecorder()
  183. r := httptest.NewRequest("POST", "/", nil)
  184. err := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
  185. return http.StatusInternalServerError, nil, Error(http.StatusForbidden, "403 forbidden", nil)
  186. }).ServeHTTPReturn(w, r)
  187. if !strings.HasPrefix(err.Error(), "[unexpected]") {
  188. t.Fatalf("returned error should have `[unexpected]` to note the disagreeing status codes: %v", err)
  189. }
  190. want := &Response{
  191. Status: "error",
  192. Data: &Data{},
  193. Error: "403 forbidden",
  194. }
  195. got := checkStatus(w, "error", http.StatusForbidden)
  196. if diff := cmp.Diff(want, got); diff != "" {
  197. t.Fatalf("(-want,+got):\n%s", diff)
  198. }
  199. })
  200. }