Browse Source

HTTPUpgrade send headers with specified capitalization (#3430)

* Fix HTTPUpgrade header capitalization

* Chore

* Remove excess host headers

Chore : change httpupgrade header "upgrade" to "Upgrade" #3435
风扇滑翔翼 1 year ago
parent
commit
3654c0d710

+ 2 - 0
infra/conf/transport_internet.go

@@ -208,8 +208,10 @@ func (c *HttpUpgradeConfig) Build() (proto.Message, error) {
 	// Host priority: Host field > headers field > address.
 	if c.Host == "" && c.Headers["host"] != "" {
 		c.Host = c.Headers["host"]
+		delete(c.Headers,"host")
 	} else if c.Host == "" && c.Headers["Host"] != "" {
 		c.Host = c.Headers["Host"]
+		delete(c.Headers,"Host")
 	}
 	config := &httpupgrade.Config{
 		Path:                path,

+ 23 - 62
transport/internet/httpupgrade/dialer.go

@@ -69,69 +69,23 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
 		requestURL.Scheme = "http"
 	}
 
-	var req *http.Request = nil
-
-	if len(transportConfiguration.Header) == 0 {
-		requestURL.Host = dest.NetAddr()
-		requestURL.Path = transportConfiguration.GetNormalizedPath()
-		req = &http.Request{
-			Method: http.MethodGet,
-			URL:    &requestURL,
-			Host:   transportConfiguration.Host,
-			Header: make(http.Header),
-		}
-
-		req.Header.Set("Connection", "upgrade")
-		req.Header.Set("Upgrade", "websocket")
-
-		err = req.Write(conn)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		var headersBuilder strings.Builder
-
-		headersBuilder.WriteString("GET ")
-		headersBuilder.WriteString(transportConfiguration.GetNormalizedPath())
-		headersBuilder.WriteString(" HTTP/1.1\r\n")
-		hasConnectionHeader := false
-		hasUpgradeHeader := false
-		hasHostHeader := false
-		for key, value := range transportConfiguration.Header {
-			if strings.ToLower(key) == "connection" {
-				hasConnectionHeader = true
-			}
-			if strings.ToLower(key) == "upgrade" {
-				hasUpgradeHeader = true
-			}
-			if strings.ToLower(key) == "host" {
-				hasHostHeader = true
-			}
-			headersBuilder.WriteString(key)
-			headersBuilder.WriteString(": ")
-			headersBuilder.WriteString(value)
-			headersBuilder.WriteString("\r\n")
-		}
-
-		if !hasConnectionHeader {
-			headersBuilder.WriteString("Connection: upgrade\r\n")
-		}
-
-		if !hasUpgradeHeader {
-			headersBuilder.WriteString("Upgrade: websocket\r\n")
-		}
-
-		if !hasHostHeader {
-			headersBuilder.WriteString("Host: ")
-			headersBuilder.WriteString(transportConfiguration.Host)
-			headersBuilder.WriteString("\r\n")
-		}
+	requestURL.Host = dest.NetAddr()
+	requestURL.Path = transportConfiguration.GetNormalizedPath()
+	req := &http.Request{
+		Method: http.MethodGet,
+		URL:    &requestURL,
+		Host:   transportConfiguration.Host,
+		Header: make(http.Header),
+	}
+	for key, value := range transportConfiguration.Header {
+		AddHeader(req.Header, key, value)
+	}
+	req.Header.Set("Connection", "Upgrade")
+	req.Header.Set("Upgrade", "websocket")
 
-		headersBuilder.WriteString("\r\n")
-		_, err = conn.Write([]byte(headersBuilder.String()))
-		if err != nil {
-			return nil, err
-		}
+	err = req.Write(conn)
+	if err != nil {
+		return nil, err
 	}
 
 	connRF := &ConnRF{
@@ -150,6 +104,13 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
 	return connRF, nil
 }
 
+//http.Header.Add() will convert headers to MIME header format.
+//Some people don't like this because they want to send "Web*S*ocket".
+//So we add a simple function to replace that method.
+func AddHeader(header http.Header, key, value string) {
+	header[key] = append(header[key], value)
+}
+
 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
 	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
 

+ 1 - 1
transport/internet/httpupgrade/hub.go

@@ -62,7 +62,7 @@ func (s *server) Handle(conn net.Conn) (stat.Connection, error) {
 		ProtoMinor: 1,
 		Header:     http.Header{},
 	}
-	resp.Header.Set("Connection", "upgrade")
+	resp.Header.Set("Connection", "Upgrade")
 	resp.Header.Set("Upgrade", "websocket")
 	err = resp.Write(conn)
 	if err != nil {