浏览代码

Fix WS reading X-Forwarded-For & Add tests (#3546)

Fixes https://github.com/XTLS/Xray-core/issues/3545

---------

Co-authored-by: mmmray <[email protected]>
风扇滑翔翼 1 年之前
父节点
当前提交
a7e198e1e2

+ 2 - 2
transport/internet/httpupgrade/httpupgrade_test.go

@@ -151,7 +151,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 				return
 			}
 
-			_, err = c.Write([]byte("Response"))
+			_, err = c.Write([]byte(c.RemoteAddr().String()))
 			common.Must(err)
 		}(conn)
 	})
@@ -169,7 +169,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 	var b [1024]byte
 	n, err := conn.Read(b[:])
 	common.Must(err)
-	if string(b[:n]) != "Response" {
+	if string(b[:n]) != "1.1.1.1:0" {
 		t.Error("response: ", string(b[:n]))
 	}
 

+ 2 - 2
transport/internet/splithttp/splithttp_test.go

@@ -96,7 +96,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 				return
 			}
 
-			_, err = c.Write([]byte("Response"))
+			_, err = c.Write([]byte(c.RemoteAddr().String()))
 			common.Must(err)
 		}(conn)
 	})
@@ -113,7 +113,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 
 	var b [1024]byte
 	n, _ := conn.Read(b[:])
-	if string(b[:n]) != "Response" {
+	if string(b[:n]) != "1.1.1.1:0" {
 		t.Error("response: ", string(b[:n]))
 	}
 

+ 9 - 5
transport/internet/websocket/connection.go

@@ -14,15 +14,19 @@ import (
 var _ buf.Writer = (*connection)(nil)
 
 // connection is a wrapper for net.Conn over WebSocket connection.
+// remoteAddr is used to pass "virtual" remote IP addresses in X-Forwarded-For.
+// so we shouldn't directly read it form conn.
 type connection struct {
-	conn   *websocket.Conn
-	reader io.Reader
+	conn       *websocket.Conn
+	reader     io.Reader
+	remoteAddr net.Addr
 }
 
 func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
 	return &connection{
-		conn:   conn,
-		reader: extraReader,
+		conn:       conn,
+		remoteAddr: remoteAddr,
+		reader:     extraReader,
 	}
 }
 
@@ -90,7 +94,7 @@ func (c *connection) LocalAddr() net.Addr {
 }
 
 func (c *connection) RemoteAddr() net.Addr {
-	return c.conn.RemoteAddr()
+	return c.remoteAddr
 }
 
 func (c *connection) SetDeadline(t time.Time) error {

+ 2 - 2
transport/internet/websocket/ws_test.go

@@ -91,7 +91,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 				return
 			}
 
-			_, err = c.Write([]byte("Response"))
+			_, err = c.Write([]byte(c.RemoteAddr().String()))
 			common.Must(err)
 		}(conn)
 	})
@@ -109,7 +109,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 	var b [1024]byte
 	n, err := conn.Read(b[:])
 	common.Must(err)
-	if string(b[:n]) != "Response" {
+	if string(b[:n]) != "1.1.1.1:0" {
 		t.Error("response: ", string(b[:n]))
 	}