Pārlūkot izejas kodu

Fix loopback detector

世界 1 gadu atpakaļ
vecāks
revīzija
7b0f5061dc
2 mainītis faili ar 52 papildinājumiem un 27 dzēšanām
  1. 3 3
      outbound/direct.go
  2. 49 24
      outbound/direct_loopback_detect.go

+ 3 - 3
outbound/direct.go

@@ -148,7 +148,7 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
 	if err != nil {
 		return nil, err
 	}
-	conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn))
+	conn = h.loopBack.NewPacketConn(bufio.NewPacketConn(conn), destination)
 	if originDestination != destination {
 		conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination)
 	}
@@ -156,14 +156,14 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
 }
 
 func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	if h.loopBack.CheckConn(metadata.Source.AddrPort()) {
+	if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
 		return E.New("reject loopback connection to ", metadata.Destination)
 	}
 	return NewConnection(ctx, h, conn, metadata)
 }
 
 func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
-	if h.loopBack.CheckPacketConn(metadata.Source.AddrPort()) {
+	if h.loopBack.CheckPacketConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) {
 		return E.New("reject loopback packet connection to ", metadata.Destination)
 	}
 	return NewPacketConnection(ctx, h, conn, metadata)

+ 49 - 24
outbound/direct_loopback_detect.go

@@ -10,58 +10,83 @@ import (
 )
 
 type loopBackDetector struct {
+	// router           adapter.Router
 	connAccess       sync.RWMutex
 	packetConnAccess sync.RWMutex
-	connMap          map[netip.AddrPort]bool
-	packetConnMap    map[netip.AddrPort]bool
+	connMap          map[netip.AddrPort]netip.AddrPort
+	packetConnMap    map[uint16]uint16
 }
 
-func newLoopBackDetector() *loopBackDetector {
+func newLoopBackDetector( /*router adapter.Router*/ ) *loopBackDetector {
 	return &loopBackDetector{
-		connMap:       make(map[netip.AddrPort]bool),
-		packetConnMap: make(map[netip.AddrPort]bool),
+		// router:        router,
+		connMap:       make(map[netip.AddrPort]netip.AddrPort),
+		packetConnMap: make(map[uint16]uint16),
 	}
 }
 
 func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
-	connAddr := M.AddrPortFromNet(conn.LocalAddr())
-	if !connAddr.IsValid() {
+	source := M.AddrPortFromNet(conn.LocalAddr())
+	if !source.IsValid() {
 		return conn
 	}
 	if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
+		/*if !source.Addr().IsLoopback() {
+			_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
+			if err != nil {
+				return conn
+			}
+		}*/
+		if !N.IsPublicAddr(source.Addr()) {
+			return conn
+		}
 		l.packetConnAccess.Lock()
-		l.packetConnMap[connAddr] = true
+		l.packetConnMap[source.Port()] = M.AddrPortFromNet(conn.RemoteAddr()).Port()
 		l.packetConnAccess.Unlock()
-		return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr}
+		return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: source.Port()}
 	} else {
 		l.connAccess.Lock()
-		l.connMap[connAddr] = true
+		l.connMap[source] = M.AddrPortFromNet(conn.RemoteAddr())
 		l.connAccess.Unlock()
-		return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: connAddr}
+		return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: source}
 	}
 }
 
-func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
-	connAddr := M.AddrPortFromNet(conn.LocalAddr())
-	if !connAddr.IsValid() {
+func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Socksaddr) N.NetPacketConn {
+	source := M.AddrPortFromNet(conn.LocalAddr())
+	if !source.IsValid() {
 		return conn
 	}
 	l.packetConnAccess.Lock()
-	l.packetConnMap[connAddr] = true
+	l.packetConnMap[source.Port()] = destination.AddrPort().Port()
 	l.packetConnAccess.Unlock()
-	return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connAddr: connAddr}
+	return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: source.Port()}
 }
 
-func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
+func (l *loopBackDetector) CheckConn(source netip.AddrPort, local netip.AddrPort) bool {
 	l.connAccess.RLock()
 	defer l.connAccess.RUnlock()
-	return l.connMap[connAddr]
+	destination, loaded := l.connMap[source]
+	return loaded && destination != local
 }
 
-func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool {
+func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.AddrPort) bool {
+	if !source.IsValid() {
+		return false
+	}
+	/*if !source.Addr().IsLoopback() {
+		_, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr())
+		if err != nil {
+			return false
+		}
+	}*/
+	if N.IsPublicAddr(source.Addr()) {
+		return false
+	}
 	l.packetConnAccess.RLock()
 	defer l.packetConnAccess.RUnlock()
-	return l.packetConnMap[connAddr]
+	destinationPort, loaded := l.packetConnMap[source.Port()]
+	return loaded && destinationPort != local.Port()
 }
 
 type loopBackDetectWrapper struct {
@@ -95,14 +120,14 @@ func (w *loopBackDetectWrapper) Upstream() any {
 type loopBackDetectPacketWrapper struct {
 	N.NetPacketConn
 	detector  *loopBackDetector
-	connAddr  netip.AddrPort
+	connPort  uint16
 	closeOnce sync.Once
 }
 
 func (w *loopBackDetectPacketWrapper) Close() error {
 	w.closeOnce.Do(func() {
 		w.detector.packetConnAccess.Lock()
-		delete(w.detector.packetConnMap, w.connAddr)
+		delete(w.detector.packetConnMap, w.connPort)
 		w.detector.packetConnAccess.Unlock()
 	})
 	return w.NetPacketConn.Close()
@@ -128,14 +153,14 @@ type abstractUDPConn interface {
 type loopBackDetectUDPWrapper struct {
 	abstractUDPConn
 	detector  *loopBackDetector
-	connAddr  netip.AddrPort
+	connPort  uint16
 	closeOnce sync.Once
 }
 
 func (w *loopBackDetectUDPWrapper) Close() error {
 	w.closeOnce.Do(func() {
 		w.detector.packetConnAccess.Lock()
-		delete(w.detector.packetConnMap, w.connAddr)
+		delete(w.detector.packetConnMap, w.connPort)
 		w.detector.packetConnAccess.Unlock()
 	})
 	return w.abstractUDPConn.Close()