Browse Source

Fix connect domain for IP outbounds

世界 2 years ago
parent
commit
1402bdab41
3 changed files with 88 additions and 8 deletions
  1. 58 0
      outbound/default.go
  2. 15 4
      outbound/socks.go
  3. 15 4
      outbound/wireguard.go

+ 58 - 0
outbound/default.go

@@ -70,6 +70,28 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a
 	return CopyEarlyConn(ctx, conn, outConn)
 }
 
+func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
+	ctx = adapter.WithContext(ctx, &metadata)
+	var outConn net.Conn
+	var err error
+	if len(metadata.DestinationAddresses) > 0 {
+		outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
+	} else if metadata.Destination.IsFqdn() {
+		var destinationAddresses []netip.Addr
+		destinationAddresses, err = router.LookupDefault(ctx, metadata.Destination.Fqdn)
+		if err != nil {
+			return N.HandshakeFailure(conn, err)
+		}
+		outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses)
+	} else {
+		outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
+	}
+	if err != nil {
+		return N.HandshakeFailure(conn, err)
+	}
+	return CopyEarlyConn(ctx, conn, outConn)
+}
+
 func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
 	ctx = adapter.WithContext(ctx, &metadata)
 	var outConn net.PacketConn
@@ -99,6 +121,42 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn,
 	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
 }
 
+func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
+	ctx = adapter.WithContext(ctx, &metadata)
+	var outConn net.PacketConn
+	var destinationAddress netip.Addr
+	var err error
+	if len(metadata.DestinationAddresses) > 0 {
+		outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
+	} else if metadata.Destination.IsFqdn() {
+		var destinationAddresses []netip.Addr
+		destinationAddresses, err = router.LookupDefault(ctx, metadata.Destination.Fqdn)
+		if err != nil {
+			return N.HandshakeFailure(conn, err)
+		}
+		outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
+	} else {
+		outConn, err = this.ListenPacket(ctx, metadata.Destination)
+	}
+	if err != nil {
+		return N.HandshakeFailure(conn, err)
+	}
+	if destinationAddress.IsValid() {
+		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
+			natConn.UpdateDestination(destinationAddress)
+		}
+	}
+	switch metadata.Protocol {
+	case C.ProtocolSTUN:
+		ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout)
+	case C.ProtocolQUIC:
+		ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout)
+	case C.ProtocolDNS:
+		ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
+	}
+	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
+}
+
 func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
 	if cachedReader, isCached := conn.(N.CachedReader); isCached {
 		payload := cachedReader.ReadCached()

+ 15 - 4
outbound/socks.go

@@ -80,11 +80,11 @@ func (h *Socks) DialContext(ctx context.Context, network string, destination M.S
 		return nil, E.Extend(N.ErrUnknownNetwork, network)
 	}
 	if h.resolve && destination.IsFqdn() {
-		addrs, err := h.router.LookupDefault(ctx, destination.Fqdn)
+		destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
 		if err != nil {
 			return nil, err
 		}
-		return N.DialSerial(ctx, h.client, network, destination, addrs)
+		return N.DialSerial(ctx, h.client, network, destination, destinationAddresses)
 	}
 	return h.client.DialContext(ctx, network, destination)
 }
@@ -97,14 +97,25 @@ func (h *Socks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.
 		h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination)
 		return h.uotClient.ListenPacket(ctx, destination)
 	}
+	if h.resolve && destination.IsFqdn() {
+		destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
+		if err != nil {
+			return nil, err
+		}
+		packetConn, _, err := N.ListenSerial(ctx, h.client, destination, destinationAddresses)
+		if err != nil {
+			return nil, err
+		}
+		return packetConn, nil
+	}
 	h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
 	return h.client.ListenPacket(ctx, destination)
 }
 
 func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return NewConnection(ctx, h, conn, metadata)
+	return NewDirectConnection(ctx, h.router, h, conn, metadata)
 }
 
 func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
-	return NewPacketConnection(ctx, h, conn, metadata)
+	return NewDirectPacketConnection(ctx, h.router, h, conn, metadata)
 }

+ 15 - 4
outbound/wireguard.go

@@ -202,26 +202,37 @@ func (w *WireGuard) DialContext(ctx context.Context, network string, destination
 		w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
 	}
 	if destination.IsFqdn() {
-		addrs, err := w.router.LookupDefault(ctx, destination.Fqdn)
+		destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
 		if err != nil {
 			return nil, err
 		}
-		return N.DialSerial(ctx, w.tunDevice, network, destination, addrs)
+		return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
 	}
 	return w.tunDevice.DialContext(ctx, network, destination)
 }
 
 func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
 	w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
+	if destination.IsFqdn() {
+		destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
+		if err != nil {
+			return nil, err
+		}
+		packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
+		if err != nil {
+			return nil, err
+		}
+		return packetConn, err
+	}
 	return w.tunDevice.ListenPacket(ctx, destination)
 }
 
 func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return NewConnection(ctx, w, conn, metadata)
+	return NewDirectConnection(ctx, w.router, w, conn, metadata)
 }
 
 func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
-	return NewPacketConnection(ctx, w, conn, metadata)
+	return NewDirectPacketConnection(ctx, w.router, w, conn, metadata)
 }
 
 func (w *WireGuard) Start() error {