ソースを参照

Add separate host config for websocket

yuhan6665 1 年間 前
コミット
7e3a8d3a04

+ 12 - 6
infra/conf/transport_internet.go

@@ -146,6 +146,7 @@ func (c *TCPConfig) Build() (proto.Message, error) {
 }
 
 type WebSocketConfig struct {
+	Host                string            `json:"host"`
 	Path                string            `json:"path"`
 	Headers             map[string]string `json:"headers"`
 	AcceptProxyProtocol bool              `json:"acceptProxyProtocol"`
@@ -154,10 +155,6 @@ type WebSocketConfig struct {
 // Build implements Buildable.
 func (c *WebSocketConfig) Build() (proto.Message, error) {
 	path := c.Path
-	header := make(map[string]string);
-	for key, value := range c.Headers {
-		header[key] = value;
-	}
 	var ed uint32
 	if u, err := url.Parse(path); err == nil {
 		if q := u.Query(); q.Get("ed") != "" {
@@ -168,9 +165,18 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
 			path = u.String()
 		}
 	}
+	// If http host is not set in the Host field, but in headers field, we add it to Host Field here.
+	// If we don't do that, http host will be overwritten as address.
+	// Host priority: Host field > headers field > address.
+	if c.Host == "" && c.Headers["host"] != "" {
+		c.Host = c.Headers["host"]
+	} else if c.Host == "" && c.Headers["Host"] != "" {
+		c.Host = c.Headers["Host"]
+	}
 	config := &websocket.Config{
 		Path:                path,
-		Header:              header,
+		Host:                c.Host,
+		Header:              c.Headers,
 		AcceptProxyProtocol: c.AcceptProxyProtocol,
 		Ed:                  ed,
 	}
@@ -178,8 +184,8 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
 }
 
 type HttpUpgradeConfig struct {
-	Path                string            `json:"path"`
 	Host                string            `json:"host"`
+	Path                string            `json:"path"`
 	Headers             map[string]string `json:"headers"`
 	AcceptProxyProtocol bool              `json:"acceptProxyProtocol"`
 }

+ 1 - 0
transport/internet/websocket/config.go

@@ -25,6 +25,7 @@ func (c *Config) GetRequestHeader() http.Header {
 	for k, v := range c.Header {
 		header.Add(k, v)
 	}
+	header.Set("Host", c.Host)
 	return header
 }
 

+ 6 - 0
transport/internet/websocket/hub.go

@@ -21,6 +21,7 @@ import (
 )
 
 type requestHandler struct {
+	host string
 	path string
 	ln   *Listener
 }
@@ -37,6 +38,10 @@ var upgrader = &websocket.Upgrader{
 }
 
 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
+	if len(h.host) > 0 && request.Host != h.host {
+		writer.WriteHeader(http.StatusNotFound)
+		return
+	}
 	if request.URL.Path != h.path {
 		writer.WriteHeader(http.StatusNotFound)
 		return
@@ -125,6 +130,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 
 	l.server = http.Server{
 		Handler: &requestHandler{
+			host: wsSettings.Host,
 			path: wsSettings.GetNormalizedPath(),
 			ln:   l,
 		},