Browse Source

preserve exact header casing when using httpupgrade (#3427)

* preserve exact header casing when using httpupgrade

* fix capitalization of websocket

* oops, we dont need net/url either

* restore old codepath when there are no headers
mmmray 1 year ago
parent
commit
980236f2b6

+ 62 - 16
transport/internet/httpupgrade/dialer.go

@@ -65,23 +65,69 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
 		requestURL.Scheme = "http"
 	}
 
-	requestURL.Host = dest.NetAddr()
-	requestURL.Path = transportConfiguration.GetNormalizedPath()
-	req := &http.Request{
-		Method: http.MethodGet,
-		URL:    &requestURL,
-		Host:   transportConfiguration.Host,
-		Header: make(http.Header),
-	}
-	for key, value := range transportConfiguration.Header {
-		req.Header.Add(key, value)
-	}
-	req.Header.Set("Connection", "upgrade")
-	req.Header.Set("Upgrade", "websocket")
+	var req *http.Request = nil
 
-	err = req.Write(conn)
-	if err != nil {
-		return nil, err
+	if len(transportConfiguration.Header) == 0 {
+		requestURL.Host = dest.NetAddr()
+		requestURL.Path = transportConfiguration.GetNormalizedPath()
+		req = &http.Request{
+			Method: http.MethodGet,
+			URL:    &requestURL,
+			Host:   transportConfiguration.Host,
+			Header: make(http.Header),
+		}
+
+		req.Header.Set("Connection", "upgrade")
+		req.Header.Set("Upgrade", "websocket")
+
+		err = req.Write(conn)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		var headersBuilder strings.Builder
+
+		headersBuilder.WriteString("GET ")
+		headersBuilder.WriteString(transportConfiguration.GetNormalizedPath())
+		headersBuilder.WriteString(" HTTP/1.1\r\n")
+		hasConnectionHeader := false
+		hasUpgradeHeader := false
+		hasHostHeader := false
+		for key, value := range transportConfiguration.Header {
+			if strings.ToLower(key) == "connection" {
+				hasConnectionHeader = true
+			}
+			if strings.ToLower(key) == "upgrade" {
+				hasUpgradeHeader = true
+			}
+			if strings.ToLower(key) == "host" {
+				hasHostHeader = true
+			}
+			headersBuilder.WriteString(key)
+			headersBuilder.WriteString(": ")
+			headersBuilder.WriteString(value)
+			headersBuilder.WriteString("\r\n")
+		}
+
+		if !hasConnectionHeader {
+			headersBuilder.WriteString("Connection: upgrade\r\n")
+		}
+
+		if !hasUpgradeHeader {
+			headersBuilder.WriteString("Upgrade: websocket\r\n")
+		}
+
+		if !hasHostHeader {
+			headersBuilder.WriteString("Host: ")
+			headersBuilder.WriteString(transportConfiguration.Host)
+			headersBuilder.WriteString("\r\n")
+		}
+
+		headersBuilder.WriteString("\r\n")
+		_, err = conn.Write([]byte(headersBuilder.String()))
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	connRF := &ConnRF{

+ 60 - 1
transport/internet/httpupgrade/httpupgrade_test.go

@@ -72,6 +72,65 @@ func Test_listenHTTPUpgradeAndDial(t *testing.T) {
 	common.Must(listen.Close())
 }
 
+func Test_listenHTTPUpgradeAndDialWithHeaders(t *testing.T) {
+	listenPort := tcp.PickPort()
+	listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
+		ProtocolName: "httpupgrade",
+		ProtocolSettings: &Config{
+			Path: "httpupgrade",
+			Header: map[string]string{
+				"User-Agent": "Mozilla",
+			},
+		},
+	}, func(conn stat.Connection) {
+		go func(c stat.Connection) {
+			defer c.Close()
+
+			var b [1024]byte
+			_, err := c.Read(b[:])
+			if err != nil {
+				return
+			}
+
+			common.Must2(c.Write([]byte("Response")))
+		}(conn)
+	})
+	common.Must(err)
+
+	ctx := context.Background()
+	streamSettings := &internet.MemoryStreamConfig{
+		ProtocolName:     "httpupgrade",
+		ProtocolSettings: &Config{Path: "httpupgrade"},
+	}
+	conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
+
+	common.Must(err)
+	_, err = conn.Write([]byte("Test connection 1"))
+	common.Must(err)
+
+	var b [1024]byte
+	n, err := conn.Read(b[:])
+	common.Must(err)
+	if string(b[:n]) != "Response" {
+		t.Error("response: ", string(b[:n]))
+	}
+
+	common.Must(conn.Close())
+	<-time.After(time.Second * 5)
+	conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
+	common.Must(err)
+	_, err = conn.Write([]byte("Test connection 2"))
+	common.Must(err)
+	n, err = conn.Read(b[:])
+	common.Must(err)
+	if string(b[:n]) != "Response" {
+		t.Error("response: ", string(b[:n]))
+	}
+	common.Must(conn.Close())
+
+	common.Must(listen.Close())
+}
+
 func TestDialWithRemoteAddr(t *testing.T) {
 	listenPort := tcp.PickPort()
 	listen, err := ListenHTTPUpgrade(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
@@ -150,4 +209,4 @@ func Test_listenHTTPUpgradeAndDial_TLS(t *testing.T) {
 	if !end.Before(start.Add(time.Second * 5)) {
 		t.Error("end: ", end, " start: ", start)
 	}
-}
+}