Browse Source

XHTTP server: Set remoteAddr & localAddr correctly

Completes https://github.com/XTLS/Xray-core/commit/22c50a70c61f18b54f9e9de82962a053261a398c
RPRX 10 months ago
parent
commit
8cb63db6c0

+ 2 - 1
common/net/system.go

@@ -76,8 +76,9 @@ type (
 )
 
 var (
-	ResolveUnixAddr = net.ResolveUnixAddr
+	ResolveTCPAddr  = net.ResolveTCPAddr
 	ResolveUDPAddr  = net.ResolveUDPAddr
+	ResolveUnixAddr = net.ResolveUnixAddr
 )
 
 type Resolver = net.Resolver

+ 19 - 16
proxy/proxy.go

@@ -113,12 +113,12 @@ type TrafficState struct {
 
 type InboundState struct {
 	// reader link state
-	WithinPaddingBuffers     bool
-	UplinkReaderDirectCopy   bool
-	RemainingCommand         int32
-	RemainingContent         int32
-	RemainingPadding         int32
-	CurrentCommand           int
+	WithinPaddingBuffers   bool
+	UplinkReaderDirectCopy bool
+	RemainingCommand       int32
+	RemainingContent       int32
+	RemainingPadding       int32
+	CurrentCommand         int
 	// write link state
 	IsPadding                bool
 	DownlinkWriterDirectCopy bool
@@ -133,19 +133,19 @@ type OutboundState struct {
 	RemainingPadding         int32
 	CurrentCommand           int
 	// write link state
-	IsPadding                bool
-	UplinkWriterDirectCopy   bool
+	IsPadding              bool
+	UplinkWriterDirectCopy bool
 }
 
 func NewTrafficState(userUUID []byte) *TrafficState {
 	return &TrafficState{
-		UserUUID:                 userUUID,
-		NumberOfPacketToFilter:   8,
-		EnableXtls:               false,
-		IsTLS12orAbove:           false,
-		IsTLS:                    false,
-		Cipher:                   0,
-		RemainingServerHello:     -1,
+		UserUUID:               userUUID,
+		NumberOfPacketToFilter: 8,
+		EnableXtls:             false,
+		IsTLS12orAbove:         false,
+		IsTLS:                  false,
+		Cipher:                 0,
+		RemainingServerHello:   -1,
 		Inbound: InboundState{
 			WithinPaddingBuffers:     true,
 			UplinkReaderDirectCopy:   false,
@@ -524,7 +524,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte
 	}
 }
 
-// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
+// UnwrapRawConn support unwrap stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
 func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
 	var readCounter, writerCounter stats.Counter
 	if conn != nil {
@@ -547,6 +547,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
 			conn = pc.Raw()
 			// 8192 > 4096, there is no need to process pc's bufReader
 		}
+		if uc, ok := conn.(*internet.UDSWrapperConn); ok {
+			conn = uc.Conn
+		}
 	}
 	return conn, readCounter, writerCounter
 }

+ 40 - 38
transport/internet/splithttp/hub.go

@@ -3,9 +3,8 @@ package splithttp
 import (
 	"bytes"
 	"context"
-	"crypto/tls"
+	gotls "crypto/tls"
 	"io"
-	gonet "net"
 	"net/http"
 	"net/url"
 	"strconv"
@@ -24,7 +23,7 @@ import (
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet/reality"
 	"github.com/xtls/xray-core/transport/internet/stat"
-	v2tls "github.com/xtls/xray-core/transport/internet/tls"
+	"github.com/xtls/xray-core/transport/internet/tls"
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/h2c"
 )
@@ -36,7 +35,7 @@ type requestHandler struct {
 	ln        *Listener
 	sessionMu *sync.Mutex
 	sessions  sync.Map
-	localAddr gonet.TCPAddr
+	localAddr net.Addr
 }
 
 type httpSession struct {
@@ -144,14 +143,25 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 	}
 
 	forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
-	remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
+	var remoteAddr net.Addr
+	var err error
+	remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr)
 	if err != nil {
-		remoteAddr = &gonet.TCPAddr{}
+		remoteAddr = &net.TCPAddr{
+			IP:   []byte{0, 0, 0, 0},
+			Port: 0,
+		}
+	}
+	if request.ProtoMajor == 3 {
+		remoteAddr = &net.UDPAddr{
+			IP:   remoteAddr.(*net.TCPAddr).IP,
+			Port: remoteAddr.(*net.TCPAddr).Port,
+		}
 	}
 	if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
 		remoteAddr = &net.TCPAddr{
 			IP:   forwardedAddrs[0].IP(),
-			Port: int(0),
+			Port: 0,
 		}
 	}
 
@@ -289,6 +299,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 				responseFlusher: responseFlusher,
 			},
 			reader:     request.Body,
+			localAddr:  h.localAddr,
 			remoteAddr: remoteAddr,
 		}
 		if sessionId != "" { // if not stream-one
@@ -362,34 +373,30 @@ type Listener struct {
 	isH3       bool
 }
 
-func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
+func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
 	l := &Listener{
 		addConn: addConn,
 	}
-	shSettings := streamSettings.ProtocolSettings.(*Config)
-	l.config = shSettings
+	l.config = streamSettings.ProtocolSettings.(*Config)
 	if l.config != nil {
 		if streamSettings.SocketSettings == nil {
 			streamSettings.SocketSettings = &internet.SocketConfig{}
 		}
 	}
-	var listener net.Listener
-	var err error
-	var localAddr = gonet.TCPAddr{}
 	handler := &requestHandler{
-		config:    shSettings,
-		host:      shSettings.Host,
-		path:      shSettings.GetNormalizedPath(),
+		config:    l.config,
+		host:      l.config.Host,
+		path:      l.config.GetNormalizedPath(),
 		ln:        l,
 		sessionMu: &sync.Mutex{},
 		sessions:  sync.Map{},
-		localAddr: localAddr,
 	}
 	tlsConfig := getTLSConfig(streamSettings)
 	l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
 
+	var err error
 	if port == net.Port(0) { // unix
-		listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
+		l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
 			Name: address.Domain(),
 			Net:  "unix",
 		}, streamSettings.SocketSettings)
@@ -405,13 +412,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 		if err != nil {
 			return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
 		}
-		h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
+		l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil)
 		if err != nil {
 			return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
 		}
-		l.h3listener = h3listener
 		errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
 
+		handler.localAddr = l.h3listener.Addr()
+
 		l.h3server = &http3.Server{
 			Handler: handler,
 		}
@@ -421,11 +429,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 			}
 		}()
 	} else { // tcp
-		localAddr = gonet.TCPAddr{
-			IP:   address.IP(),
-			Port: int(port),
-		}
-		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
+		l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			Port: int(port),
 		}, streamSettings.SocketSettings)
@@ -436,26 +440,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
 	}
 
 	// tcp/unix (h1/h2)
-	if listener != nil {
-		if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
+	if l.listener != nil {
+		if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
 			if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
-				listener = tls.NewListener(listener, tlsConfig)
+				l.listener = gotls.NewListener(l.listener, tlsConfig)
 			}
 		}
-
 		if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
-			listener = goreality.NewListener(listener, config.GetREALITYConfig())
+			l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig())
 		}
 
+		handler.localAddr = l.listener.Addr()
+
 		// h2cHandler can handle both plaintext HTTP/1.1 and h2c
-		h2cHandler := h2c.NewHandler(handler, &http2.Server{})
-		l.listener = listener
 		l.server = http.Server{
-			Handler:           h2cHandler,
+			Handler:           h2c.NewHandler(handler, &http2.Server{}),
 			ReadHeaderTimeout: time.Second * 4,
 			MaxHeaderBytes:    8192,
 		}
-
 		go func() {
 			if err := l.server.Serve(l.listener); err != nil {
 				errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP")
@@ -488,13 +490,13 @@ func (ln *Listener) Close() error {
 	}
 	return errors.New("listener does not have an HTTP/3 server or a net.listener")
 }
-func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
-	config := v2tls.ConfigFromStreamSettings(streamSettings)
+func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config {
+	config := tls.ConfigFromStreamSettings(streamSettings)
 	if config == nil {
-		return &tls.Config{}
+		return &gotls.Config{}
 	}
 	return config.GetTLSConfig()
 }
 func init() {
-	common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
+	common.Must(internet.RegisterTransportListener(protocolName, ListenXH))
 }

+ 13 - 13
transport/internet/splithttp/splithttp_test.go

@@ -26,9 +26,9 @@ import (
 	"golang.org/x/net/http2"
 )
 
-func Test_listenSHAndDial(t *testing.T) {
+func Test_ListenXHAndDial(t *testing.T) {
 	listenPort := tcp.PickPort()
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
 		ProtocolName: "splithttp",
 		ProtocolSettings: &Config{
 			Path: "/sh",
@@ -85,7 +85,7 @@ func Test_listenSHAndDial(t *testing.T) {
 
 func TestDialWithRemoteAddr(t *testing.T) {
 	listenPort := tcp.PickPort()
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
 		ProtocolName: "splithttp",
 		ProtocolSettings: &Config{
 			Path: "sh",
@@ -125,7 +125,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
 	common.Must(listen.Close())
 }
 
-func Test_listenSHAndDial_TLS(t *testing.T) {
+func Test_ListenXHAndDial_TLS(t *testing.T) {
 	if runtime.GOARCH == "arm64" {
 		return
 	}
@@ -145,7 +145,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
 			Certificate:   []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
 		},
 	}
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
 		go func() {
 			defer conn.Close()
 
@@ -180,7 +180,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
 	}
 }
 
-func Test_listenSHAndDial_H2C(t *testing.T) {
+func Test_ListenXHAndDial_H2C(t *testing.T) {
 	if runtime.GOARCH == "arm64" {
 		return
 	}
@@ -193,7 +193,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
 			Path: "shs",
 		},
 	}
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
 		go func() {
 			_ = conn.Close()
 		}()
@@ -227,7 +227,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
 	}
 }
 
-func Test_listenSHAndDial_QUIC(t *testing.T) {
+func Test_ListenXHAndDial_QUIC(t *testing.T) {
 	if runtime.GOARCH == "arm64" {
 		return
 	}
@@ -250,7 +250,7 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 	}
 
 	serverClosed := false
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
 		go func() {
 			defer conn.Close()
 
@@ -309,11 +309,11 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
 	}
 }
 
-func Test_listenSHAndDial_Unix(t *testing.T) {
+func Test_ListenXHAndDial_Unix(t *testing.T) {
 	tempDir := t.TempDir()
 	tempSocket := tempDir + "/server.sock"
 
-	listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
+	listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
 		ProtocolName: "splithttp",
 		ProtocolSettings: &Config{
 			Path: "/sh",
@@ -373,7 +373,7 @@ func Test_listenSHAndDial_Unix(t *testing.T) {
 
 func Test_queryString(t *testing.T) {
 	listenPort := tcp.PickPort()
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
 		ProtocolName: "splithttp",
 		ProtocolSettings: &Config{
 			// this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break
@@ -431,7 +431,7 @@ func Test_maxUpload(t *testing.T) {
 	}
 
 	var uploadSize int
-	listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
+	listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
 		go func(c stat.Connection) {
 			defer c.Close()
 			var b [10240]byte

+ 3 - 3
transport/internet/system_listener.go

@@ -54,7 +54,7 @@ func (l *listenUDSWrapper) Accept() (net.Conn, error) {
 	if err != nil {
 		return nil, err
 	}
-	return &listenUDSWrapperConn{Conn: conn}, nil
+	return &UDSWrapperConn{Conn: conn}, nil
 }
 
 func (l *listenUDSWrapper) Close() error {
@@ -65,11 +65,11 @@ func (l *listenUDSWrapper) Close() error {
 	return l.Listener.Close()
 }
 
-type listenUDSWrapperConn struct {
+type UDSWrapperConn struct {
 	net.Conn
 }
 
-func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr {
+func (conn *UDSWrapperConn) RemoteAddr() net.Addr {
 	return &net.TCPAddr{
 		IP: []byte{0, 0, 0, 0},
 	}