浏览代码

Fix override packet conn

世界 2 年之前
父节点
当前提交
53f19a6ead
共有 1 个文件被更改,包括 43 次插入2 次删除
  1. 43 2
      outbound/direct.go

+ 43 - 2
outbound/direct.go

@@ -12,6 +12,8 @@ import (
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing-dns"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -160,8 +162,30 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
 	ctx, metadata := adapter.AppendContext(ctx)
 	metadata.Outbound = h.tag
 	metadata.Destination = destination
-	h.logger.InfoContext(ctx, "outbound packet connection")
-	return h.dialer.ListenPacket(ctx, destination)
+	switch h.overrideOption {
+	case 1:
+		destination = h.overrideDestination
+	case 2:
+		newDestination := h.overrideDestination
+		newDestination.Port = destination.Port
+		destination = newDestination
+	case 3:
+		destination.Port = h.overrideDestination.Port
+	}
+	if h.overrideOption == 0 {
+		h.logger.InfoContext(ctx, "outbound packet connection")
+	} else {
+		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
+	}
+	conn, err := h.dialer.ListenPacket(ctx, destination)
+	if err != nil {
+		return nil, err
+	}
+	if h.overrideOption == 0 {
+		return conn, nil
+	} else {
+		return &overridePacketConn{bufio.NewPacketConn(conn), destination}, nil
+	}
 }
 
 func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
@@ -171,3 +195,20 @@ func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adap
 func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
 	return NewPacketConnection(ctx, h, conn, metadata)
 }
+
+type overridePacketConn struct {
+	N.NetPacketConn
+	overrideDestination M.Socksaddr
+}
+
+func (c *overridePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	return c.NetPacketConn.WritePacket(buffer, c.overrideDestination)
+}
+
+func (c *overridePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+	return c.NetPacketConn.WriteTo(p, c.overrideDestination.UDPAddr())
+}
+
+func (c *overridePacketConn) Upstream() any {
+	return c.NetPacketConn
+}