Browse Source

VLESS Reverse Proxy: Transfer real Source & Local (IP & port), enabled by default

https://t.me/projectXtls/1039

https://github.com/XTLS/Xray-core/pull/5101#issuecomment-3404979909
RPRX 5 days ago
parent
commit
12f4a014e0

+ 1 - 1
app/proxyman/outbound/handler.go

@@ -108,7 +108,7 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
 	}
 	h.proxyConfig = proxyConfig
 
-	ctx = session.ContextWithHandler(ctx, h)
+	ctx = session.ContextWithFullHandler(ctx, h)
 
 	rawProxyHandler, err := common.CreateObject(ctx, proxyConfig)
 	if err != nil {

+ 10 - 6
app/reverse/bridge.go

@@ -198,9 +198,11 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
 
 func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
 	if !isInternalDomain(dest) {
-		ctx = session.ContextWithInbound(ctx, &session.Inbound{
-			Tag: w.Tag,
-		})
+		if session.InboundFromContext(ctx) == nil {
+			ctx = session.ContextWithInbound(ctx, &session.Inbound{
+				Tag: w.Tag,
+			})
+		}
 		return w.Dispatcher.Dispatch(ctx, dest)
 	}
 
@@ -221,9 +223,11 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra
 
 func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
 	if !isInternalDomain(dest) {
-		ctx = session.ContextWithInbound(ctx, &session.Inbound{
-			Tag: w.Tag,
-		})
+		if session.InboundFromContext(ctx) == nil {
+			ctx = session.ContextWithInbound(ctx, &session.Inbound{
+				Tag: w.Tag,
+			})
+		}
 		return w.Dispatcher.DispatchLink(ctx, dest, link)
 	}
 

+ 6 - 2
common/mux/client.go

@@ -264,7 +264,11 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 		transferType = protocol.TransferTypePacket
 	}
 	s.transferType = transferType
-	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
+	var inbound *session.Inbound
+	if session.IsReverseMuxFromContext(ctx) {
+		inbound = session.InboundFromContext(ctx)
+	}
+	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx), inbound)
 	defer s.Close(false)
 	defer writer.Close()
 
@@ -384,7 +388,7 @@ func (m *ClientWorker) fetchOutput() {
 
 	var meta FrameMetadata
 	for {
-		err := meta.Unmarshal(reader)
+		err := meta.Unmarshal(reader, false)
 		if err != nil {
 			if errors.Cause(err) != io.EOF {
 				errors.LogInfoInner(context.Background(), err, "failed to read metadata")

+ 67 - 5
common/mux/frame.go

@@ -11,6 +11,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/serial"
+	"github.com/xtls/xray-core/common/session"
 )
 
 type SessionStatus byte
@@ -60,6 +61,7 @@ type FrameMetadata struct {
 	Option        bitmask.Byte
 	SessionStatus SessionStatus
 	GlobalID      [8]byte
+	Inbound       *session.Inbound
 }
 
 func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
@@ -79,11 +81,23 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 		case net.Network_UDP:
 			common.Must(b.WriteByte(byte(TargetNetworkUDP)))
 		}
-
 		if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {
 			return err
 		}
-		if b.UDP != nil { // make sure it's user's proxy request
+		if f.Inbound != nil {
+			if f.Inbound.Source.Network == net.Network_TCP || f.Inbound.Source.Network == net.Network_UDP {
+				common.Must(b.WriteByte(byte(f.Inbound.Source.Network - 1)))
+				if err := addrParser.WriteAddressPort(b, f.Inbound.Source.Address, f.Inbound.Source.Port); err != nil {
+					return err
+				}
+				if f.Inbound.Local.Network == net.Network_TCP || f.Inbound.Local.Network == net.Network_UDP {
+					common.Must(b.WriteByte(byte(f.Inbound.Local.Network - 1)))
+					if err := addrParser.WriteAddressPort(b, f.Inbound.Local.Address, f.Inbound.Local.Port); err != nil {
+						return err
+					}
+				}
+			}
+		} else if b.UDP != nil { // make sure it's user's proxy request
 			b.Write(f.GlobalID[:]) // no need to check whether it's empty
 		}
 	} else if b.UDP != nil {
@@ -97,7 +111,7 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 }
 
 // Unmarshal reads FrameMetadata from the given reader.
-func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
+func (f *FrameMetadata) Unmarshal(reader io.Reader, readSourceAndLocal bool) error {
 	metaLen, err := serial.ReadUint16(reader)
 	if err != nil {
 		return err
@@ -112,12 +126,12 @@ func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
 	if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil {
 		return err
 	}
-	return f.UnmarshalFromBuffer(b)
+	return f.UnmarshalFromBuffer(b, readSourceAndLocal)
 }
 
 // UnmarshalFromBuffer reads a FrameMetadata from the given buffer.
 // Visible for testing only.
-func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error {
+func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer, readSourceAndLocal bool) error {
 	if b.Len() < 4 {
 		return errors.New("insufficient buffer: ", b.Len())
 	}
@@ -150,6 +164,54 @@ func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error {
 		}
 	}
 
+	if f.SessionStatus == SessionStatusNew && readSourceAndLocal {
+		f.Inbound = &session.Inbound{}
+
+		if b.Len() == 0 {
+			return nil // for heartbeat, etc.
+		}
+		network := TargetNetwork(b.Byte(0))
+		if network == 0 {
+			return nil // may be padding
+		}
+		b.Advance(1)
+		addr, port, err := addrParser.ReadAddressPort(nil, b)
+		if err != nil {
+			return errors.New("reading source: failed to parse address and port").Base(err)
+		}
+		switch network {
+		case TargetNetworkTCP:
+			f.Inbound.Source = net.TCPDestination(addr, port)
+		case TargetNetworkUDP:
+			f.Inbound.Source = net.UDPDestination(addr, port)
+		default:
+			return errors.New("reading source: unknown network type: ", network)
+		}
+
+		if b.Len() == 0 {
+			return nil
+		}
+		network = TargetNetwork(b.Byte(0))
+		if network == 0 {
+			return nil
+		}
+		b.Advance(1)
+		addr, port, err = addrParser.ReadAddressPort(nil, b)
+		if err != nil {
+			return errors.New("reading local: failed to parse address and port").Base(err)
+		}
+		switch network {
+		case TargetNetworkTCP:
+			f.Inbound.Local = net.TCPDestination(addr, port)
+		case TargetNetworkUDP:
+			f.Inbound.Local = net.UDPDestination(addr, port)
+		default:
+			return errors.New("reading local: unknown network type: ", network)
+		}
+
+		return nil
+	}
+
 	// Application data is essential, to test whether the pipe is closed.
 	if f.SessionStatus == SessionStatusNew && f.Option.Has(OptionData) &&
 		f.Target.Network == net.Network_UDP && b.Len() >= 8 {

+ 13 - 12
common/mux/mux_test.go

@@ -10,6 +10,7 @@ import (
 	. "github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/transport/pipe"
 )
 
@@ -32,13 +33,13 @@ func TestReaderWriter(t *testing.T) {
 	pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
 
 	dest := net.TCPDestination(net.DomainAddress("example.com"), 80)
-	writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream, [8]byte{})
+	writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream, [8]byte{}, &session.Inbound{})
 
 	dest2 := net.TCPDestination(net.LocalHostIP, 443)
-	writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream, [8]byte{})
+	writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream, [8]byte{}, &session.Inbound{})
 
 	dest3 := net.TCPDestination(net.LocalHostIPv6, 18374)
-	writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream, [8]byte{})
+	writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream, [8]byte{}, &session.Inbound{})
 
 	writePayload := func(writer *Writer, payload ...byte) error {
 		b := buf.New()
@@ -62,7 +63,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     1,
 			SessionStatus: SessionStatusNew,
@@ -81,7 +82,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionStatus: SessionStatusNew,
 			SessionID:     2,
@@ -94,7 +95,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     1,
 			SessionStatus: SessionStatusKeep,
@@ -112,7 +113,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     3,
 			SessionStatus: SessionStatusNew,
@@ -131,7 +132,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     1,
 			SessionStatus: SessionStatusEnd,
@@ -143,7 +144,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     3,
 			SessionStatus: SessionStatusEnd,
@@ -155,7 +156,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     2,
 			SessionStatus: SessionStatusKeep,
@@ -173,7 +174,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		common.Must(meta.Unmarshal(bytesReader))
+		common.Must(meta.Unmarshal(bytesReader, false))
 		if r := cmp.Diff(meta, FrameMetadata{
 			SessionID:     2,
 			SessionStatus: SessionStatusEnd,
@@ -187,7 +188,7 @@ func TestReaderWriter(t *testing.T) {
 
 	{
 		var meta FrameMetadata
-		err := meta.Unmarshal(bytesReader)
+		err := meta.Unmarshal(bytesReader, false)
 		if err == nil {
 			t.Error("nil error")
 		}

+ 10 - 2
common/mux/server.go

@@ -166,6 +166,14 @@ func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.Bu
 
 func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
 	ctx = session.SubContextFromMuxInbound(ctx)
+	if meta.Inbound != nil && meta.Inbound.Source.IsValid() && meta.Inbound.Local.IsValid() {
+		if inbound := session.InboundFromContext(ctx); inbound != nil {
+			newInbound := *inbound
+			newInbound.Source = meta.Inbound.Source
+			newInbound.Local = meta.Inbound.Local
+			ctx = session.ContextWithInbound(ctx, &newInbound)
+		}
+	}
 	errors.LogInfo(ctx, "received request for ", meta.Target)
 	{
 		msg := &log.AccessMessage{
@@ -329,7 +337,7 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.Buffered
 
 func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error {
 	var meta FrameMetadata
-	err := meta.Unmarshal(reader)
+	err := meta.Unmarshal(reader, session.IsReverseMuxFromContext(ctx))
 	if err != nil {
 		return errors.New("failed to read metadata").Base(err)
 	}
@@ -340,7 +348,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
 	case SessionStatusEnd:
 		err = w.handleStatusEnd(&meta, reader)
 	case SessionStatusNew:
-		err = w.handleStatusNew(ctx, &meta, reader)
+		err = w.handleStatusNew(session.ContextWithIsReverseMux(ctx, false), &meta, reader)
 	case SessionStatusKeep:
 		err = w.handleStatusKeep(&meta, reader)
 	default:

+ 5 - 1
common/mux/writer.go

@@ -6,6 +6,7 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/serial"
+	"github.com/xtls/xray-core/common/session"
 )
 
 type Writer struct {
@@ -16,9 +17,10 @@ type Writer struct {
 	hasError     bool
 	transferType protocol.TransferType
 	globalID     [8]byte
+	inbound      *session.Inbound
 }
 
-func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType, globalID [8]byte) *Writer {
+func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType, globalID [8]byte, inbound *session.Inbound) *Writer {
 	return &Writer{
 		id:           id,
 		dest:         dest,
@@ -26,6 +28,7 @@ func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType
 		followup:     false,
 		transferType: transferType,
 		globalID:     globalID,
+		inbound:      inbound,
 	}
 }
 
@@ -43,6 +46,7 @@ func (w *Writer) getNextFrameMeta() FrameMetadata {
 		SessionID: w.id,
 		Target:    w.dest,
 		GlobalID:  w.globalID,
+		Inbound:   w.inbound,
 	}
 
 	if w.followup {

+ 10 - 14
common/session/context.go

@@ -17,13 +17,13 @@ const (
 	inboundSessionKey         ctx.SessionKey = 1
 	outboundSessionKey        ctx.SessionKey = 2
 	contentSessionKey         ctx.SessionKey = 3
-	muxPreferredSessionKey    ctx.SessionKey = 4  // unused
+	isReverseMuxKey           ctx.SessionKey = 4  // is reverse mux
 	sockoptSessionKey         ctx.SessionKey = 5  // used by dokodemo to only receive sockopt.Mark
 	trackedConnectionErrorKey ctx.SessionKey = 6  // used by observer to get outbound error
 	dispatcherKey             ctx.SessionKey = 7  // used by ss2022 inbounds to get dispatcher
 	timeoutOnlyKey            ctx.SessionKey = 8  // mux context's child contexts to only cancel when its own traffic times out
 	allowedNetworkKey         ctx.SessionKey = 9  // muxcool server control incoming request tcp/udp
-	handlerSessionKey         ctx.SessionKey = 10 // outbound gets full handler
+	fullHandlerKey            ctx.SessionKey = 10 // outbound gets full handler
 	mitmAlpn11Key             ctx.SessionKey = 11 // used by TLS dialer
 	mitmServerNameKey         ctx.SessionKey = 12 // used by TLS dialer
 )
@@ -75,25 +75,21 @@ func ContentFromContext(ctx context.Context) *Content {
 	return nil
 }
 
-// ContextWithMuxPreferred returns a new context with the given bool
-func ContextWithMuxPreferred(ctx context.Context, forced bool) context.Context {
-	return context.WithValue(ctx, muxPreferredSessionKey, forced)
+func ContextWithIsReverseMux(ctx context.Context, isReverseMux bool) context.Context {
+	return context.WithValue(ctx, isReverseMuxKey, isReverseMux)
 }
 
-// MuxPreferredFromContext returns value in this context, or false if not contained.
-func MuxPreferredFromContext(ctx context.Context) bool {
-	if val, ok := ctx.Value(muxPreferredSessionKey).(bool); ok {
+func IsReverseMuxFromContext(ctx context.Context) bool {
+	if val, ok := ctx.Value(isReverseMuxKey).(bool); ok {
 		return val
 	}
 	return false
 }
 
-// ContextWithSockopt returns a new context with Socket configs included
 func ContextWithSockopt(ctx context.Context, s *Sockopt) context.Context {
 	return context.WithValue(ctx, sockoptSessionKey, s)
 }
 
-// SockoptFromContext returns Socket configs in this context, or nil if not contained.
 func SockoptFromContext(ctx context.Context) *Sockopt {
 	if sockopt, ok := ctx.Value(sockoptSessionKey).(*Sockopt); ok {
 		return sockopt
@@ -164,12 +160,12 @@ func AllowedNetworkFromContext(ctx context.Context) net.Network {
 	return net.Network_Unknown
 }
 
-func ContextWithHandler(ctx context.Context, handler outbound.Handler) context.Context {
-	return context.WithValue(ctx, handlerSessionKey, handler)
+func ContextWithFullHandler(ctx context.Context, handler outbound.Handler) context.Context {
+	return context.WithValue(ctx, fullHandlerKey, handler)
 }
 
-func HandlerFromContext(ctx context.Context) outbound.Handler {
-	if val, ok := ctx.Value(handlerSessionKey).(outbound.Handler); ok {
+func FullHandlerFromContext(ctx context.Context) outbound.Handler {
+	if val, ok := ctx.Value(fullHandlerKey).(outbound.Handler); ok {
 		return val
 	}
 	return nil

+ 1 - 1
proxy/vless/inbound/inbound.go

@@ -666,7 +666,7 @@ func (r *Reverse) Dispatch(ctx context.Context, link *transport.Link) {
 			link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
 			link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
 		}
-		r.client.Dispatch(ctx, link)
+		r.client.Dispatch(session.ContextWithIsReverseMux(ctx, true), link)
 	}
 }
 

+ 7 - 4
proxy/vless/outbound/outbound.go

@@ -89,8 +89,11 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
 		handler.reverse = &Reverse{
 			tag:        a.Reverse.Tag,
 			dispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher),
-			ctx:        ctx,
-			handler:    handler,
+			ctx: session.ContextWithInbound(ctx, &session.Inbound{
+				Tag:  a.Reverse.Tag,
+				User: handler.server.User, // TODO: email
+			}),
+			handler: handler,
 		}
 		handler.reverse.monitorTask = &task.Periodic{
 			Execute:  handler.reverse.monitor,
@@ -397,7 +400,7 @@ func (r *Reverse) monitor() error {
 			Tag:        r.tag,
 			Dispatcher: r.dispatcher,
 		}
-		worker, err := mux.NewServerWorker(r.ctx, w, link1)
+		worker, err := mux.NewServerWorker(session.ContextWithIsReverseMux(r.ctx, true), w, link1)
 		if err != nil {
 			errors.LogWarningInner(r.ctx, err, "failed to create mux server worker")
 			return nil
@@ -408,7 +411,7 @@ func (r *Reverse) monitor() error {
 			ctx := session.ContextWithOutbounds(r.ctx, []*session.Outbound{{
 				Target: net.Destination{Address: net.DomainAddress("v1.rvs.cool")},
 			}})
-			r.handler.Process(ctx, link2, session.HandlerFromContext(ctx).(*proxyman.Handler))
+			r.handler.Process(ctx, link2, session.FullHandlerFromContext(ctx).(*proxyman.Handler))
 			common.Interrupt(reader1)
 			common.Interrupt(reader2)
 		}()