Jelajahi Sumber

chore(stdiscosrv): reduce allocations in cert handling

Jakob Borg 1 tahun lalu
induk
melakukan
f9b72330a8
2 mengubah file dengan 29 tambahan dan 10 penghapusan
  1. 26 9
      cmd/stdiscosrv/apisrv.go
  2. 3 1
      cmd/stdiscosrv/apisrv_test.go

+ 26 - 9
cmd/stdiscosrv/apisrv.go

@@ -367,7 +367,7 @@ func certificateBytes(req *http.Request) ([]byte, error) {
 		}
 
 		bs = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: hdr})
-	} else if hdr := req.Header.Get("X-Forwarded-Tls-Client-Cert"); hdr != "" {
+	} else if cert := req.Header.Get("X-Forwarded-Tls-Client-Cert"); cert != "" {
 		// Traefik 2 passtlsclientcert
 		//
 		// The certificate is in PEM format, maybe with URL encoding
@@ -375,19 +375,36 @@ func certificateBytes(req *http.Request) ([]byte, error) {
 		// statements. We need to decode, reinstate the newlines every 64
 		// character and add statements for the PEM decoder
 
-		if strings.Contains(hdr, "%") {
-			if unesc, err := url.QueryUnescape(hdr); err == nil {
-				hdr = unesc
+		if strings.Contains(cert, "%") {
+			if unesc, err := url.QueryUnescape(cert); err == nil {
+				cert = unesc
 			}
 		}
 
-		for i := 64; i < len(hdr); i += 65 {
-			hdr = hdr[:i] + "\n" + hdr[i:]
+		const (
+			header = "-----BEGIN CERTIFICATE-----"
+			footer = "-----END CERTIFICATE-----"
+		)
+
+		var b bytes.Buffer
+		b.Grow(len(header) + 1 + len(cert) + len(cert)/64 + 1 + len(footer) + 1)
+
+		b.WriteString(header)
+		b.WriteByte('\n')
+
+		for i := 0; i < len(cert); i += 64 {
+			end := i + 64
+			if end > len(cert) {
+				end = len(cert)
+			}
+			b.WriteString(cert[i:end])
+			b.WriteByte('\n')
 		}
 
-		hdr = "-----BEGIN CERTIFICATE-----\n" + hdr
-		hdr += "\n-----END CERTIFICATE-----\n"
-		bs = []byte(hdr)
+		b.WriteString(footer)
+		b.WriteByte('\n')
+
+		bs = b.Bytes()
 	}
 
 	if bs == nil {

+ 3 - 1
cmd/stdiscosrv/apisrv_test.go

@@ -15,6 +15,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"regexp"
 	"strings"
 	"testing"
 
@@ -122,6 +123,7 @@ func BenchmarkAPIRequests(b *testing.B) {
 	if err != nil {
 		b.Fatal(err)
 	}
+	certBs = regexp.MustCompile(`---[^\n]+---\n`).ReplaceAll(certBs, nil)
 	certString := string(strings.ReplaceAll(string(certBs), "\n", " "))
 
 	devID := protocol.NewDeviceID(crt.Certificate[0])
@@ -132,7 +134,7 @@ func BenchmarkAPIRequests(b *testing.B) {
 		url := srv.URL + "/v2/?device=" + devIDString
 		for i := 0; i < b.N; i++ {
 			req, _ := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"addresses":["tcp://10.10.10.10:42000"]}`))
-			req.Header.Set("X-Ssl-Cert", certString)
+			req.Header.Set("X-Forwarded-Tls-Client-Cert", certString)
 			resp, err := http.DefaultClient.Do(req)
 			if err != nil {
 				b.Fatal(err)