Browse Source

feat(stdiscosrv): make compression optional (and faster)

Jakob Borg 1 year ago
parent
commit
68a1fd010f
3 changed files with 103 additions and 4 deletions
  1. 12 3
      cmd/stdiscosrv/apisrv.go
  2. 88 0
      cmd/stdiscosrv/apisrv_test.go
  3. 3 1
      cmd/stdiscosrv/main.go

+ 12 - 3
cmd/stdiscosrv/apisrv.go

@@ -45,7 +45,9 @@ type apiSrv struct {
 	listener       net.Listener
 	repl           replicator // optional
 	useHTTP        bool
+	compression    bool
 	missesIncrease int
+	gzipWriters    sync.Pool
 
 	mapsMut sync.Mutex
 	misses  map[string]int32
@@ -61,13 +63,14 @@ type contextKey int
 
 const idKey contextKey = iota
 
-func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP bool, missesIncrease int) *apiSrv {
+func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP, compression bool, missesIncrease int) *apiSrv {
 	return &apiSrv{
 		addr:           addr,
 		cert:           cert,
 		db:             db,
 		repl:           repl,
 		useHTTP:        useHTTP,
+		compression:    compression,
 		misses:         make(map[string]int32),
 		missesIncrease: missesIncrease,
 	}
@@ -226,10 +229,16 @@ func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) {
 	var bw io.Writer = w
 
 	// Use compression if the client asks for it
-	if strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
+	if s.compression && strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
+		gw, ok := s.gzipWriters.Get().(*gzip.Writer)
+		if ok {
+			gw.Reset(w)
+		} else {
+			gw = gzip.NewWriter(w)
+		}
 		w.Header().Set("Content-Encoding", "gzip")
-		gw := gzip.NewWriter(bw)
 		defer gw.Close()
+		defer s.gzipWriters.Put(gw)
 		bw = gw
 	}
 

+ 88 - 0
cmd/stdiscosrv/apisrv_test.go

@@ -7,9 +7,19 @@
 package main
 
 import (
+	"context"
+	"crypto/tls"
 	"fmt"
+	"io"
 	"net"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strings"
 	"testing"
+
+	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/tlsutil"
 )
 
 func TestFixupAddresses(t *testing.T) {
@@ -94,3 +104,81 @@ func addr(host string, port int) *net.TCPAddr {
 		Port: port,
 	}
 }
+
+func BenchmarkAPIRequests(b *testing.B) {
+	db, err := newLevelDBStore(b.TempDir())
+	if err != nil {
+		b.Fatal(err)
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go db.Serve(ctx)
+	api := newAPISrv("127.0.0.1:0", tls.Certificate{}, db, nil, true, true, 1)
+	srv := httptest.NewServer(http.HandlerFunc(api.handler))
+
+	kf := b.TempDir() + "/cert"
+	crt, err := tlsutil.NewCertificate(kf+".crt", kf+".key", "localhost", 7)
+	if err != nil {
+		b.Fatal(err)
+	}
+	certBs, err := os.ReadFile(kf + ".crt")
+	if err != nil {
+		b.Fatal(err)
+	}
+	certString := string(strings.ReplaceAll(string(certBs), "\n", " "))
+
+	devID := protocol.NewDeviceID(crt.Certificate[0])
+	devIDString := devID.String()
+
+	b.Run("Announce", func(b *testing.B) {
+		b.ReportAllocs()
+		url := srv.URL + "/v2/?device=" + devIDString
+		for i := 0; i < b.N; i++ {
+			req, _ := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"addresses":["tcp://10.10.10.10:42000"]}`))
+			req.Header.Set("X-Ssl-Cert", certString)
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				b.Fatal(err)
+			}
+			resp.Body.Close()
+			if resp.StatusCode != http.StatusNoContent {
+				b.Fatalf("unexpected status %s", resp.Status)
+			}
+		}
+	})
+
+	b.Run("Lookup", func(b *testing.B) {
+		b.ReportAllocs()
+		url := srv.URL + "/v2/?device=" + devIDString
+		for i := 0; i < b.N; i++ {
+			req, _ := http.NewRequest(http.MethodGet, url, nil)
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				b.Fatal(err)
+			}
+			io.Copy(io.Discard, resp.Body)
+			resp.Body.Close()
+			if resp.StatusCode != http.StatusOK {
+				b.Fatalf("unexpected status %s", resp.Status)
+			}
+		}
+	})
+
+	b.Run("LookupNoCompression", func(b *testing.B) {
+		b.ReportAllocs()
+		url := srv.URL + "/v2/?device=" + devIDString
+		for i := 0; i < b.N; i++ {
+			req, _ := http.NewRequest(http.MethodGet, url, nil)
+			req.Header.Set("Accept-Encoding", "identity") // disable compression
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				b.Fatal(err)
+			}
+			io.Copy(io.Discard, resp.Body)
+			resp.Body.Close()
+			if resp.StatusCode != http.StatusOK {
+				b.Fatalf("unexpected status %s", resp.Status)
+			}
+		}
+	})
+}

+ 3 - 1
cmd/stdiscosrv/main.go

@@ -80,6 +80,7 @@ func main() {
 	var replCertFile string
 	var replKeyFile string
 	var useHTTP bool
+	var compression bool
 	var largeDB bool
 	var amqpAddress string
 	missesIncrease := 1
@@ -92,6 +93,7 @@ func main() {
 	flag.StringVar(&dir, "db-dir", "./discovery.db", "Database directory")
 	flag.BoolVar(&debug, "debug", false, "Print debug output")
 	flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)")
+	flag.BoolVar(&compression, "compression", true, "Enable GZIP compression of responses")
 	flag.StringVar(&listen, "listen", ":8443", "Listen address")
 	flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address")
 	flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated")
@@ -225,7 +227,7 @@ func main() {
 	}()
 
 	// Start the main API server.
-	qs := newAPISrv(listen, cert, db, repl, useHTTP, missesIncrease)
+	qs := newAPISrv(listen, cert, db, repl, useHTTP, compression, missesIncrease)
 	main.Add(qs)
 
 	// If we have a metrics port configured, start a metrics handler.