Browse Source

Improve tproxy udp write back

世界 3 years ago
parent
commit
ca94a2ddcb
1 changed files with 32 additions and 16 deletions
  1. 32 16
      inbound/tproxy.go

+ 32 - 16
inbound/tproxy.go

@@ -12,6 +12,7 @@ import (
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/control"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -54,21 +55,17 @@ func (t *TProxy) Start() error {
 		return err
 	}
 	if t.tcpListener != nil {
-		tcpFd, err := common.GetFileDescriptor(t.tcpListener)
-		if err != nil {
-			return err
-		}
-		err = redir.TProxy(tcpFd, M.SocksaddrFromNet(t.tcpListener.Addr()).Addr.Is6())
+		err = control.Conn(t.tcpListener, func(fd uintptr) error {
+			return redir.TProxy(fd, M.SocksaddrFromNet(t.tcpListener.Addr()).Addr.Is6())
+		})
 		if err != nil {
 			return E.Cause(err, "configure tproxy TCP listener")
 		}
 	}
 	if t.udpConn != nil {
-		udpFd, err := common.GetFileDescriptor(t.udpConn)
-		if err != nil {
-			return err
-		}
-		err = redir.TProxy(udpFd, M.SocksaddrFromNet(t.udpConn.LocalAddr()).Addr.Is6())
+		err = control.Conn(t.udpConn, func(fd uintptr) error {
+			return redir.TProxy(fd, M.SocksaddrFromNet(t.udpConn.LocalAddr()).Addr.Is6())
+		})
 		if err != nil {
 			return E.Cause(err, "configure tproxy UDP listener")
 		}
@@ -88,21 +85,40 @@ func (t *TProxy) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.B
 	}
 	metadata.Destination = M.SocksaddrFromNetIP(destination)
 	t.udpNat.NewContextPacket(ctx, metadata.Source.AddrPort(), buffer, adapter.UpstreamMetadata(metadata), func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
-		return adapter.WithContext(log.ContextWithNewID(ctx), &metadata), &tproxyPacketWriter{natConn}
+		return adapter.WithContext(log.ContextWithNewID(ctx), &metadata), &tproxyPacketWriter{source: natConn}
 	})
 	return nil
 }
 
 type tproxyPacketWriter struct {
-	source N.PacketConn
+	source      N.PacketConn
+	destination M.Socksaddr
+	conn        *net.UDPConn
 }
 
 func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
 	defer buffer.Release()
-	udpConn, err := redir.DialUDP(destination.UDPAddr(), M.SocksaddrFromNet(w.source.LocalAddr()).UDPAddr())
-	if err != nil {
-		return E.Cause(err, "tproxy udp write back")
+	var udpConn *net.UDPConn
+	if w.destination == destination {
+		if w.conn != nil {
+			udpConn = w.conn
+		}
+	}
+	if udpConn == nil {
+		var err error
+		udpConn, err = redir.DialUDP(destination.UDPAddr(), M.SocksaddrFromNet(w.source.LocalAddr()).UDPAddr())
+		if err != nil {
+			return E.Cause(err, "tproxy udp write back")
+		}
+		if w.destination == destination {
+			w.conn = udpConn
+		} else {
+			defer udpConn.Close()
+		}
 	}
-	defer udpConn.Close()
 	return common.Error(udpConn.Write(buffer.Bytes()))
 }
+
+func (w *tproxyPacketWriter) Close() error {
+	return common.Close(common.PtrOrNil(w.conn))
+}