Sfoglia il codice sorgente

tsweb: add gzip support to JSONHandlerFunc

Change-Id: I337e05f92f744bfc7e9d6fb8e67c87c191ba4da8
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 anni fa
parent
commit
df8f02db3f
2 ha cambiato i file con 151 aggiunte e 18 eliminazioni
  1. 76 2
      tsweb/jsonhandler.go
  2. 75 16
      tsweb/jsonhandler_test.go

+ 76 - 2
tsweb/jsonhandler.go

@@ -5,9 +5,17 @@
 package tsweb
 
 import (
+	"bytes"
+	"compress/gzip"
 	"encoding/json"
 	"fmt"
+	"io/ioutil"
 	"net/http"
+	"strconv"
+	"strings"
+	"sync"
+
+	"go4.org/mem"
 )
 
 type response struct {
@@ -85,7 +93,73 @@ func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request
 		return jerr
 	}
 
-	w.WriteHeader(status)
-	w.Write(b)
+	if AcceptsEncoding(r, "gzip") {
+		encb, err := gzipBytes(b)
+		if err != nil {
+			return err
+		}
+		w.Header().Set("Content-Encoding", "gzip")
+		w.Header().Set("Content-Length", strconv.Itoa(len(encb)))
+		w.Write(encb)
+	} else {
+		w.Header().Set("Content-Length", strconv.Itoa(len(b)))
+		w.WriteHeader(status)
+		w.Write(b)
+	}
 	return err
 }
+
+var gzWriterPool sync.Pool // of *gzip.Writer
+
+// gzipBytes returns the gzipped encoding of b.
+func gzipBytes(b []byte) (zb []byte, err error) {
+	var buf bytes.Buffer
+	zw, ok := gzWriterPool.Get().(*gzip.Writer)
+	if ok {
+		zw.Reset(&buf)
+	} else {
+		zw = gzip.NewWriter(&buf)
+	}
+	defer gzWriterPool.Put(zw)
+	if _, err := zw.Write(b); err != nil {
+		return nil, err
+	}
+	if err := zw.Close(); err != nil {
+		return nil, err
+	}
+	zb = buf.Bytes()
+	zw.Reset(ioutil.Discard)
+	return zb, nil
+}
+
+// AcceptsEncoding reports whether r accepts the named encoding
+// ("gzip", "br", etc).
+func AcceptsEncoding(r *http.Request, enc string) bool {
+	h := r.Header.Get("Accept-Encoding")
+	if h == "" {
+		return false
+	}
+	if !strings.Contains(h, enc) && !mem.ContainsFold(mem.S(h), mem.S(enc)) {
+		return false
+	}
+	remain := h
+	for len(remain) > 0 {
+		comma := strings.Index(remain, ",")
+		var part string
+		if comma == -1 {
+			part = remain
+			remain = ""
+		} else {
+			part = remain[:comma]
+			remain = remain[comma+1:]
+		}
+		part = strings.TrimSpace(part)
+		if i := strings.Index(part, ";"); i != -1 {
+			part = part[:i]
+		}
+		if part == enc {
+			return true
+		}
+	}
+	return false
+}

+ 75 - 16
tsweb/jsonhandler_test.go

@@ -5,8 +5,11 @@
 package tsweb
 
 import (
+	"bytes"
+	"compress/gzip"
 	"encoding/json"
 	"fmt"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"strings"
@@ -27,13 +30,25 @@ type Response struct {
 }
 
 func TestNewJSONHandler(t *testing.T) {
-	checkStatus := func(w *httptest.ResponseRecorder, status string, code int) *Response {
+	checkStatus := func(t *testing.T, w *httptest.ResponseRecorder, status string, code int) *Response {
 		d := &Response{
 			Data: &Data{},
 		}
 
-		t.Logf("%s", w.Body.Bytes())
-		err := json.Unmarshal(w.Body.Bytes(), d)
+		bodyBytes := w.Body.Bytes()
+		if w.Result().Header.Get("Content-Encoding") == "gzip" {
+			zr, err := gzip.NewReader(bytes.NewReader(bodyBytes))
+			if err != nil {
+				t.Fatalf("gzip read error at start: %v", err)
+			}
+			bodyBytes, err = io.ReadAll(zr)
+			if err != nil {
+				t.Fatalf("gzip read error: %v", err)
+			}
+		}
+
+		t.Logf("%s", bodyBytes)
+		err := json.Unmarshal(bodyBytes, d)
 		if err != nil {
 			t.Logf(err.Error())
 			return nil
@@ -64,7 +79,7 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("GET", "/", nil)
 		h21.ServeHTTPReturn(w, r)
-		checkStatus(w, "success", http.StatusOK)
+		checkStatus(t, w, "success", http.StatusOK)
 	})
 
 	t.Run("403 HTTPError", func(t *testing.T) {
@@ -75,7 +90,7 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("GET", "/", nil)
 		h.ServeHTTPReturn(w, r)
-		checkStatus(w, "error", http.StatusForbidden)
+		checkStatus(t, w, "error", http.StatusForbidden)
 	})
 
 	h22 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
@@ -86,7 +101,7 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("GET", "/", nil)
 		h22.ServeHTTPReturn(w, r)
-		checkStatus(w, "success", http.StatusOK)
+		checkStatus(t, w, "success", http.StatusOK)
 	})
 
 	h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
@@ -105,21 +120,21 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
 		h31.ServeHTTPReturn(w, r)
-		checkStatus(w, "success", http.StatusOK)
+		checkStatus(t, w, "success", http.StatusOK)
 	})
 
 	t.Run("400 bad json", func(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
 		h31.ServeHTTPReturn(w, r)
-		checkStatus(w, "error", http.StatusBadRequest)
+		checkStatus(t, w, "error", http.StatusBadRequest)
 	})
 
 	t.Run("400 post data error", func(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
 		h31.ServeHTTPReturn(w, r)
-		resp := checkStatus(w, "error", http.StatusBadRequest)
+		resp := checkStatus(t, w, "error", http.StatusBadRequest)
 		if resp.Error != "name is empty" {
 			t.Fatalf("wrong error")
 		}
@@ -144,7 +159,23 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
 		h32.ServeHTTPReturn(w, r)
-		resp := checkStatus(w, "success", http.StatusOK)
+		resp := checkStatus(t, w, "success", http.StatusOK)
+		t.Log(resp.Data)
+		if resp.Data.Price != 20 {
+			t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
+		}
+	})
+
+	t.Run("gzipped", func(t *testing.T) {
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
+		r.Header.Set("Accept-Encoding", "gzip")
+		h32.ServeHTTPReturn(w, r)
+		res := w.Result()
+		if ct := res.Header.Get("Content-Encoding"); ct != "gzip" {
+			t.Fatalf("encoding = %q; want gzip", ct)
+		}
+		resp := checkStatus(t, w, "success", http.StatusOK)
 		t.Log(resp.Data)
 		if resp.Data.Price != 20 {
 			t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
@@ -155,7 +186,7 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
 		h32.ServeHTTPReturn(w, r)
-		resp := checkStatus(w, "error", http.StatusBadRequest)
+		resp := checkStatus(t, w, "error", http.StatusBadRequest)
 		if resp.Error != "price is empty" {
 			t.Fatalf("wrong error")
 		}
@@ -165,7 +196,7 @@ func TestNewJSONHandler(t *testing.T) {
 		w := httptest.NewRecorder()
 		r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
 		h32.ServeHTTPReturn(w, r)
-		resp := checkStatus(w, "error", http.StatusInternalServerError)
+		resp := checkStatus(t, w, "error", http.StatusInternalServerError)
 		if resp.Error != "internal server error" {
 			t.Fatalf("wrong error")
 		}
@@ -177,7 +208,7 @@ func TestNewJSONHandler(t *testing.T) {
 		JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
 			return http.StatusOK, make(chan int), nil
 		}).ServeHTTPReturn(w, r)
-		resp := checkStatus(w, "error", http.StatusInternalServerError)
+		resp := checkStatus(t, w, "error", http.StatusInternalServerError)
 		if resp.Error != "json marshal error" {
 			t.Fatalf("wrong error")
 		}
@@ -189,7 +220,7 @@ func TestNewJSONHandler(t *testing.T) {
 		JSONHandlerFunc(func(r *http.Request) (status int, data interface{}, err error) {
 			return
 		}).ServeHTTPReturn(w, r)
-		checkStatus(w, "error", http.StatusInternalServerError)
+		checkStatus(t, w, "error", http.StatusInternalServerError)
 	})
 
 	t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) {
@@ -203,7 +234,7 @@ func TestNewJSONHandler(t *testing.T) {
 			Data:   &Data{},
 			Error:  "403 forbidden",
 		}
-		got := checkStatus(w, "error", http.StatusForbidden)
+		got := checkStatus(t, w, "error", http.StatusForbidden)
 		if diff := cmp.Diff(want, got); diff != "" {
 			t.Fatalf(diff)
 		}
@@ -223,9 +254,37 @@ func TestNewJSONHandler(t *testing.T) {
 			Data:   &Data{},
 			Error:  "403 forbidden",
 		}
-		got := checkStatus(w, "error", http.StatusForbidden)
+		got := checkStatus(t, w, "error", http.StatusForbidden)
 		if diff := cmp.Diff(want, got); diff != "" {
 			t.Fatalf("(-want,+got):\n%s", diff)
 		}
 	})
 }
+
+func TestAcceptsEncoding(t *testing.T) {
+	tests := []struct {
+		in, enc string
+		want    bool
+	}{
+		{"", "gzip", false},
+		{"gzip", "gzip", true},
+		{"foo,gzip", "gzip", true},
+		{"foo, gzip", "gzip", true},
+		{"foo, gzip ", "gzip", true},
+		{"gzip, foo ", "gzip", true},
+		{"gzip, foo ", "br", false},
+		{"gzip, foo ", "fo", false},
+		{"gzip;q=1.2, foo ", "gzip", true},
+		{" gzip;q=1.2, foo ", "gzip", true},
+	}
+	for i, tt := range tests {
+		h := make(http.Header)
+		if tt.in != "" {
+			h.Set("Accept-Encoding", tt.in)
+		}
+		got := AcceptsEncoding(&http.Request{Header: h}, tt.enc)
+		if got != tt.want {
+			t.Errorf("%d. got %v; want %v", i, got, tt.want)
+		}
+	}
+}