浏览代码

Fix UDP domain NAT

世界 2 年之前
父节点
当前提交
cb2e15f8a7
共有 11 个文件被更改,包括 138 次插入13 次删除
  1. 8 0
      adapter/inbound.go
  2. 2 2
      common/dialer/resolve.go
  3. 9 2
      outbound/default.go
  4. 1 1
      outbound/direct.go
  5. 1 1
      route/router.go
  6. 27 0
      test/box_test.go
  7. 83 0
      test/domain_inbound_test.go
  8. 3 3
      test/go.mod
  9. 2 2
      test/go.sum
  10. 1 1
      test/hysteria2_test.go
  11. 1 1
      test/wireguard_test.go

+ 8 - 0
adapter/inbound.go

@@ -75,3 +75,11 @@ func AppendContext(ctx context.Context) (context.Context, *InboundContext) {
 	metadata = new(InboundContext)
 	return WithContext(ctx, metadata), metadata
 }
+
+func ExtendContext(ctx context.Context) (context.Context, *InboundContext) {
+	var newMetadata InboundContext
+	if metadata := ContextFrom(ctx); metadata != nil {
+		newMetadata = *metadata
+	}
+	return WithContext(ctx, &newMetadata), &newMetadata
+}

+ 2 - 2
common/dialer/resolve.go

@@ -36,7 +36,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
 	if !destination.IsFqdn() {
 		return d.dialer.DialContext(ctx, network, destination)
 	}
-	ctx, metadata := adapter.AppendContext(ctx)
+	ctx, metadata := adapter.ExtendContext(ctx)
 	ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
 	metadata.Destination = destination
 	metadata.Domain = ""
@@ -61,7 +61,7 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
 	if !destination.IsFqdn() {
 		return d.dialer.ListenPacket(ctx, destination)
 	}
-	ctx, metadata := adapter.AppendContext(ctx)
+	ctx, metadata := adapter.ExtendContext(ctx)
 	ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
 	metadata.Destination = destination
 	metadata.Domain = ""

+ 9 - 2
outbound/default.go

@@ -17,6 +17,7 @@ import (
 	"github.com/sagernet/sing/common/bufio"
 	"github.com/sagernet/sing/common/canceler"
 	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 )
 
@@ -119,7 +120,10 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
 		return err
 	}
 	if destinationAddress.IsValid() {
-		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
+		if metadata.Destination.IsFqdn() {
+			outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
+		}
+		if natConn, loaded := common.Cast[*bufio.NATPacketConn](conn); loaded {
 			natConn.UpdateDestination(destinationAddress)
 		}
 	}
@@ -159,7 +163,10 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this
 		return err
 	}
 	if destinationAddress.IsValid() {
-		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
+		if metadata.Destination.IsFqdn() {
+			outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination)
+		}
+		if natConn, loaded := common.Cast[*bufio.NATPacketConn](conn); loaded {
 			natConn.UpdateDestination(destinationAddress)
 		}
 	}

+ 1 - 1
outbound/direct.go

@@ -164,7 +164,7 @@ func (h *Direct) DialParallel(ctx context.Context, network string, destination M
 }
 
 func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
-	ctx, metadata := adapter.AppendContext(ctx)
+	ctx, metadata := adapter.ExtendContext(ctx)
 	metadata.Outbound = h.tag
 	metadata.Destination = destination
 	switch h.overrideOption {

+ 1 - 1
route/router.go

@@ -835,7 +835,7 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 		}
 	}
 	if metadata.FakeIP {
-		conn = fakeip.NewNATPacketConn(conn, metadata.OriginDestination, metadata.Destination)
+		conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination)
 	}
 	return detour.NewPacketConnection(ctx, conn, metadata)
 }

+ 27 - 0
test/box_test.go

@@ -2,10 +2,15 @@ package main
 
 import (
 	"context"
+	"crypto/tls"
+	"io"
 	"net"
+	"net/http"
 	"testing"
 	"time"
 
+	"github.com/sagernet/quic-go"
+	"github.com/sagernet/quic-go/http3"
 	"github.com/sagernet/sing-box"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
@@ -74,6 +79,28 @@ func testSuit(t *testing.T, clientPort uint16, testPort uint16) {
 	// require.NoError(t, testPacketConnTimeout(t, dialUDP))
 }
 
+func testQUIC(t *testing.T, clientPort uint16) {
+	dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
+	client := &http.Client{
+		Transport: &http3.RoundTripper{
+			Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
+				destination := M.ParseSocksaddr(addr)
+				udpConn, err := dialer.DialContext(ctx, N.NetworkUDP, destination)
+				if err != nil {
+					return nil, err
+				}
+				return quic.DialEarly(ctx, udpConn.(net.PacketConn), destination, tlsCfg, cfg)
+			},
+		},
+	}
+	response, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
+	require.NoError(t, err)
+	require.Equal(t, http.StatusOK, response.StatusCode)
+	content, err := io.ReadAll(response.Body)
+	require.NoError(t, err)
+	println(string(content))
+}
+
 func testSuitLargeUDP(t *testing.T, clientPort uint16, testPort uint16) {
 	dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
 	dialTCP := func() (net.Conn, error) {

+ 83 - 0
test/domain_inbound_test.go

@@ -0,0 +1,83 @@
+package main
+
+import (
+	"net/netip"
+	"testing"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/option"
+	dns "github.com/sagernet/sing-dns"
+
+	"github.com/gofrs/uuid/v5"
+)
+
+func TestTUICDomainUDP(t *testing.T) {
+	_, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
+	startInstance(t, option.Options{
+		Inbounds: []option.Inbound{
+			{
+				Type: C.TypeMixed,
+				Tag:  "mixed-in",
+				MixedOptions: option.HTTPMixedInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.NewListenAddress(netip.IPv4Unspecified()),
+						ListenPort: clientPort,
+					},
+				},
+			},
+			{
+				Type: C.TypeTUIC,
+				TUICOptions: option.TUICInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.NewListenAddress(netip.IPv4Unspecified()),
+						ListenPort: serverPort,
+						InboundOptions: option.InboundOptions{
+							DomainStrategy: option.DomainStrategy(dns.DomainStrategyUseIPv6),
+						},
+					},
+					Users: []option.TUICUser{{
+						UUID: uuid.Nil.String(),
+					}},
+					TLS: &option.InboundTLSOptions{
+						Enabled:         true,
+						ServerName:      "example.org",
+						CertificatePath: certPem,
+						KeyPath:         keyPem,
+					},
+				},
+			},
+		},
+		Outbounds: []option.Outbound{
+			{
+				Type: C.TypeDirect,
+			},
+			{
+				Type: C.TypeTUIC,
+				Tag:  "tuic-out",
+				TUICOptions: option.TUICOutboundOptions{
+					ServerOptions: option.ServerOptions{
+						Server:     "127.0.0.1",
+						ServerPort: serverPort,
+					},
+					UUID: uuid.Nil.String(),
+					TLS: &option.OutboundTLSOptions{
+						Enabled:         true,
+						ServerName:      "example.org",
+						CertificatePath: certPem,
+					},
+				},
+			},
+		},
+		Route: &option.RouteOptions{
+			Rules: []option.Rule{
+				{
+					DefaultOptions: option.DefaultRule{
+						Inbound:  []string{"mixed-in"},
+						Outbound: "tuic-out",
+					},
+				},
+			},
+		},
+	})
+	testQUIC(t, clientPort)
+}

+ 3 - 3
test/go.mod

@@ -10,7 +10,9 @@ require (
 	github.com/docker/docker v24.0.6+incompatible
 	github.com/docker/go-connections v0.4.0
 	github.com/gofrs/uuid/v5 v5.0.0
+	github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee
 	github.com/sagernet/sing v0.2.15
+	github.com/sagernet/sing-dns v0.1.10
 	github.com/sagernet/sing-quic v0.1.2
 	github.com/sagernet/sing-shadowsocks v0.2.5
 	github.com/sagernet/sing-shadowsocks2 v0.1.4
@@ -73,12 +75,10 @@ require (
 	github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 // indirect
 	github.com/sagernet/gvisor v0.0.0-20230627031050-1ab0276e0dd2 // indirect
 	github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 // indirect
-	github.com/sagernet/quic-go v0.0.0-20230919101909-0cc6c5dcecee // indirect
 	github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 // indirect
-	github.com/sagernet/sing-dns v0.1.10 // indirect
 	github.com/sagernet/sing-mux v0.1.3 // indirect
 	github.com/sagernet/sing-shadowtls v0.1.4 // indirect
-	github.com/sagernet/sing-tun v0.1.15 // indirect
+	github.com/sagernet/sing-tun v0.1.16 // indirect
 	github.com/sagernet/sing-vmess v0.1.8 // indirect
 	github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 // indirect
 	github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 // indirect

+ 2 - 2
test/go.sum

@@ -147,8 +147,8 @@ github.com/sagernet/sing-shadowsocks2 v0.1.4 h1:vht2M8t3m5DTgXR2j24KbYOygG5aOp+M
 github.com/sagernet/sing-shadowsocks2 v0.1.4/go.mod h1:Mgdee99NxxNd5Zld3ixIs18yVs4x2dI2VTDDE1N14Wc=
 github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k=
 github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4=
-github.com/sagernet/sing-tun v0.1.15 h1:XfHQD/dhCCQeespPojB4gRhADI1A/4mSLLJCnh5qUnQ=
-github.com/sagernet/sing-tun v0.1.15/go.mod h1:zgRoBAtOM24QXx0IKYFEnuTtXPq1Z4rDYRWkP8kJm+g=
+github.com/sagernet/sing-tun v0.1.16 h1:RHXYIVg6uacvdfbYMiPEz9VX5uu6mNrvP7u9yAH3oNc=
+github.com/sagernet/sing-tun v0.1.16/go.mod h1:S3q8GCjeyRniK+KLmo4XqKY0bS3x2UdKkKbqxT/Agl8=
 github.com/sagernet/sing-vmess v0.1.8 h1:XVWad1RpTy9b5tPxdm5MCU8cGfrTGdR8qCq6HV2aCNc=
 github.com/sagernet/sing-vmess v0.1.8/go.mod h1:vhx32UNzTDUkNwOyIjcZQohre1CaytquC5mPplId8uA=
 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=

+ 1 - 1
test/hysteria2_test.go

@@ -97,7 +97,7 @@ func testHysteria2Self(t *testing.T, salamanderPassword string) {
 			},
 		},
 	})
-	testSuit(t, clientPort, testPort)
+	testSuitLargeUDP(t, clientPort, testPort)
 }
 
 func TestHysteria2Inbound(t *testing.T) {

+ 1 - 1
test/wireguard_test.go

@@ -40,7 +40,7 @@ func _TestWireGuard(t *testing.T) {
 						Server:     "127.0.0.1",
 						ServerPort: serverPort,
 					},
-					LocalAddress:  []option.ListenPrefix{option.ListenPrefix(netip.MustParsePrefix("10.0.0.2/32"))},
+					LocalAddress:  []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")},
 					PrivateKey:    "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
 					PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
 				},