jsonhandler.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tsweb
  5. import (
  6. "bytes"
  7. "compress/gzip"
  8. "encoding/json"
  9. "fmt"
  10. "io/ioutil"
  11. "net/http"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "go4.org/mem"
  16. )
  17. type response struct {
  18. Status string `json:"status"`
  19. Error string `json:"error,omitempty"`
  20. Data interface{} `json:"data,omitempty"`
  21. }
  22. // JSONHandlerFunc is an HTTP ReturnHandler that writes JSON responses to the client.
  23. //
  24. // Return a HTTPError to show an error message, otherwise JSONHandlerFunc will
  25. // only report "internal server error" to the user with status code 500.
  26. type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error)
  27. // ServeHTTPReturn implements the ReturnHandler interface.
  28. //
  29. // Use the following code to unmarshal the request body
  30. //
  31. // body := new(DataType)
  32. // if err := json.NewDecoder(r.Body).Decode(body); err != nil {
  33. // return http.StatusBadRequest, nil, err
  34. // }
  35. //
  36. // See jsonhandler_test.go for examples.
  37. func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
  38. w.Header().Set("Content-Type", "application/json")
  39. var resp *response
  40. status, data, err := fn(r)
  41. if err != nil {
  42. if werr, ok := err.(HTTPError); ok {
  43. resp = &response{
  44. Status: "error",
  45. Error: werr.Msg,
  46. Data: data,
  47. }
  48. // Unwrap the HTTPError here because we are communicating with
  49. // the client in this handler. We don't want the wrapping
  50. // ReturnHandler to do it too.
  51. err = werr.Err
  52. if werr.Msg != "" {
  53. err = fmt.Errorf("%s: %w", werr.Msg, err)
  54. }
  55. // take status from the HTTPError to encourage error handling in one location
  56. if status != 0 && status != werr.Code {
  57. err = fmt.Errorf("[unexpected] non-zero status that does not match HTTPError status, status: %d, HTTPError.code: %d: %w", status, werr.Code, err)
  58. }
  59. status = werr.Code
  60. } else {
  61. status = http.StatusInternalServerError
  62. resp = &response{
  63. Status: "error",
  64. Error: "internal server error",
  65. }
  66. }
  67. } else if status == 0 {
  68. status = http.StatusInternalServerError
  69. resp = &response{
  70. Status: "error",
  71. Error: "internal server error",
  72. }
  73. } else if err == nil {
  74. resp = &response{
  75. Status: "success",
  76. Data: data,
  77. }
  78. }
  79. b, jerr := json.Marshal(resp)
  80. if jerr != nil {
  81. w.WriteHeader(http.StatusInternalServerError)
  82. w.Write([]byte(`{"status":"error","error":"json marshal error"}`))
  83. if err != nil {
  84. return fmt.Errorf("%w, and then we could not respond: %v", err, jerr)
  85. }
  86. return jerr
  87. }
  88. if AcceptsEncoding(r, "gzip") {
  89. encb, err := gzipBytes(b)
  90. if err != nil {
  91. return err
  92. }
  93. w.Header().Set("Content-Encoding", "gzip")
  94. w.Header().Set("Content-Length", strconv.Itoa(len(encb)))
  95. w.WriteHeader(status)
  96. w.Write(encb)
  97. } else {
  98. w.Header().Set("Content-Length", strconv.Itoa(len(b)))
  99. w.WriteHeader(status)
  100. w.Write(b)
  101. }
  102. return err
  103. }
  104. var gzWriterPool sync.Pool // of *gzip.Writer
  105. // gzipBytes returns the gzipped encoding of b.
  106. func gzipBytes(b []byte) (zb []byte, err error) {
  107. var buf bytes.Buffer
  108. zw, ok := gzWriterPool.Get().(*gzip.Writer)
  109. if ok {
  110. zw.Reset(&buf)
  111. } else {
  112. zw = gzip.NewWriter(&buf)
  113. }
  114. defer gzWriterPool.Put(zw)
  115. if _, err := zw.Write(b); err != nil {
  116. return nil, err
  117. }
  118. if err := zw.Close(); err != nil {
  119. return nil, err
  120. }
  121. zb = buf.Bytes()
  122. zw.Reset(ioutil.Discard)
  123. return zb, nil
  124. }
  125. // AcceptsEncoding reports whether r accepts the named encoding
  126. // ("gzip", "br", etc).
  127. func AcceptsEncoding(r *http.Request, enc string) bool {
  128. h := r.Header.Get("Accept-Encoding")
  129. if h == "" {
  130. return false
  131. }
  132. if !strings.Contains(h, enc) && !mem.ContainsFold(mem.S(h), mem.S(enc)) {
  133. return false
  134. }
  135. remain := h
  136. for len(remain) > 0 {
  137. comma := strings.Index(remain, ",")
  138. var part string
  139. if comma == -1 {
  140. part = remain
  141. remain = ""
  142. } else {
  143. part = remain[:comma]
  144. remain = remain[comma+1:]
  145. }
  146. part = strings.TrimSpace(part)
  147. if i := strings.Index(part, ";"); i != -1 {
  148. part = part[:i]
  149. }
  150. if part == enc {
  151. return true
  152. }
  153. }
  154. return false
  155. }