|  | @@ -5,7 +5,9 @@ package wireguard
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  |  	"context"
 | 
	
		
			
				|  |  |  	"net"
 | 
	
		
			
				|  |  | +	"net/netip"
 | 
	
		
			
				|  |  |  	"os"
 | 
	
		
			
				|  |  | +	"time"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/buffer"
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/tcpip"
 | 
	
	
		
			
				|  | @@ -14,9 +16,14 @@ import (
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
 | 
	
		
			
				|  |  | +	"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
 | 
	
		
			
				|  |  |  	"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
 | 
	
		
			
				|  |  | +	"github.com/sagernet/sing-box/adapter"
 | 
	
		
			
				|  |  | +	"github.com/sagernet/sing-box/log"
 | 
	
		
			
				|  |  |  	"github.com/sagernet/sing-tun"
 | 
	
		
			
				|  |  | +	"github.com/sagernet/sing-tun/ping"
 | 
	
		
			
				|  |  | +	"github.com/sagernet/sing/common/buf"
 | 
	
		
			
				|  |  |  	E "github.com/sagernet/sing/common/exceptions"
 | 
	
		
			
				|  |  |  	M "github.com/sagernet/sing/common/metadata"
 | 
	
		
			
				|  |  |  	N "github.com/sagernet/sing/common/network"
 | 
	
	
		
			
				|  | @@ -24,30 +31,40 @@ import (
 | 
	
		
			
				|  |  |  	wgTun "github.com/sagernet/wireguard-go/tun"
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -var _ Device = (*stackDevice)(nil)
 | 
	
		
			
				|  |  | +var _ NatDevice = (*stackDevice)(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  type stackDevice struct {
 | 
	
		
			
				|  |  | -	stack      *stack.Stack
 | 
	
		
			
				|  |  | -	mtu        uint32
 | 
	
		
			
				|  |  | -	events     chan wgTun.Event
 | 
	
		
			
				|  |  | -	outbound   chan *stack.PacketBuffer
 | 
	
		
			
				|  |  | -	done       chan struct{}
 | 
	
		
			
				|  |  | -	dispatcher stack.NetworkDispatcher
 | 
	
		
			
				|  |  | -	addr4      tcpip.Address
 | 
	
		
			
				|  |  | -	addr6      tcpip.Address
 | 
	
		
			
				|  |  | +	ctx            context.Context
 | 
	
		
			
				|  |  | +	logger         log.ContextLogger
 | 
	
		
			
				|  |  | +	stack          *stack.Stack
 | 
	
		
			
				|  |  | +	mtu            uint32
 | 
	
		
			
				|  |  | +	events         chan wgTun.Event
 | 
	
		
			
				|  |  | +	outbound       chan *stack.PacketBuffer
 | 
	
		
			
				|  |  | +	packetOutbound chan *buf.Buffer
 | 
	
		
			
				|  |  | +	done           chan struct{}
 | 
	
		
			
				|  |  | +	dispatcher     stack.NetworkDispatcher
 | 
	
		
			
				|  |  | +	inet4Address   netip.Addr
 | 
	
		
			
				|  |  | +	inet6Address   netip.Addr
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func newStackDevice(options DeviceOptions) (*stackDevice, error) {
 | 
	
		
			
				|  |  |  	tunDevice := &stackDevice{
 | 
	
		
			
				|  |  | -		mtu:      options.MTU,
 | 
	
		
			
				|  |  | -		events:   make(chan wgTun.Event, 1),
 | 
	
		
			
				|  |  | -		outbound: make(chan *stack.PacketBuffer, 256),
 | 
	
		
			
				|  |  | -		done:     make(chan struct{}),
 | 
	
		
			
				|  |  | +		ctx:            options.Context,
 | 
	
		
			
				|  |  | +		logger:         options.Logger,
 | 
	
		
			
				|  |  | +		mtu:            options.MTU,
 | 
	
		
			
				|  |  | +		events:         make(chan wgTun.Event, 1),
 | 
	
		
			
				|  |  | +		outbound:       make(chan *stack.PacketBuffer, 256),
 | 
	
		
			
				|  |  | +		packetOutbound: make(chan *buf.Buffer, 256),
 | 
	
		
			
				|  |  | +		done:           make(chan struct{}),
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
 | 
	
		
			
				|  |  | +	ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true)
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
		
			
				|  |  |  		return nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | +	var (
 | 
	
		
			
				|  |  | +		inet4Address netip.Addr
 | 
	
		
			
				|  |  | +		inet6Address netip.Addr
 | 
	
		
			
				|  |  | +	)
 | 
	
		
			
				|  |  |  	for _, prefix := range options.Address {
 | 
	
		
			
				|  |  |  		addr := tun.AddressFromAddr(prefix.Addr())
 | 
	
		
			
				|  |  |  		protoAddr := tcpip.ProtocolAddress{
 | 
	
	
		
			
				|  | @@ -57,10 +74,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
 | 
	
		
			
				|  |  |  			},
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		if prefix.Addr().Is4() {
 | 
	
		
			
				|  |  | -			tunDevice.addr4 = addr
 | 
	
		
			
				|  |  | +			inet4Address = prefix.Addr()
 | 
	
		
			
				|  |  | +			tunDevice.inet4Address = inet4Address
 | 
	
		
			
				|  |  |  			protoAddr.Protocol = ipv4.ProtocolNumber
 | 
	
		
			
				|  |  |  		} else {
 | 
	
		
			
				|  |  | -			tunDevice.addr6 = addr
 | 
	
		
			
				|  |  | +			inet6Address = prefix.Addr()
 | 
	
		
			
				|  |  | +			tunDevice.inet6Address = inet6Address
 | 
	
		
			
				|  |  |  			protoAddr.Protocol = ipv6.ProtocolNumber
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
 | 
	
	
		
			
				|  | @@ -72,6 +91,10 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
 | 
	
		
			
				|  |  |  	if options.Handler != nil {
 | 
	
		
			
				|  |  |  		ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
 | 
	
		
			
				|  |  |  		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
 | 
	
		
			
				|  |  | +		icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
 | 
	
		
			
				|  |  | +		icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
 | 
	
		
			
				|  |  | +		ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
 | 
	
		
			
				|  |  | +		ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	return tunDevice, nil
 | 
	
		
			
				|  |  |  }
 | 
	
	
		
			
				|  | @@ -87,11 +110,17 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	var networkProtocol tcpip.NetworkProtocolNumber
 | 
	
		
			
				|  |  |  	if destination.IsIPv4() {
 | 
	
		
			
				|  |  | +		if !w.inet4Address.IsValid() {
 | 
	
		
			
				|  |  | +			return nil, E.New("missing IPv4 local address")
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  |  		networkProtocol = header.IPv4ProtocolNumber
 | 
	
		
			
				|  |  | -		bind.Addr = w.addr4
 | 
	
		
			
				|  |  | +		bind.Addr = tun.AddressFromAddr(w.inet4Address)
 | 
	
		
			
				|  |  |  	} else {
 | 
	
		
			
				|  |  | +		if !w.inet6Address.IsValid() {
 | 
	
		
			
				|  |  | +			return nil, E.New("missing IPv6 local address")
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  |  		networkProtocol = header.IPv6ProtocolNumber
 | 
	
		
			
				|  |  | -		bind.Addr = w.addr6
 | 
	
		
			
				|  |  | +		bind.Addr = tun.AddressFromAddr(w.inet6Address)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	switch N.NetworkName(network) {
 | 
	
		
			
				|  |  |  	case N.NetworkTCP:
 | 
	
	
		
			
				|  | @@ -118,10 +147,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
 | 
	
		
			
				|  |  |  	var networkProtocol tcpip.NetworkProtocolNumber
 | 
	
		
			
				|  |  |  	if destination.IsIPv4() {
 | 
	
		
			
				|  |  |  		networkProtocol = header.IPv4ProtocolNumber
 | 
	
		
			
				|  |  | -		bind.Addr = w.addr4
 | 
	
		
			
				|  |  | +		bind.Addr = tun.AddressFromAddr(w.inet4Address)
 | 
	
		
			
				|  |  |  	} else {
 | 
	
		
			
				|  |  |  		networkProtocol = header.IPv6ProtocolNumber
 | 
	
		
			
				|  |  | -		bind.Addr = w.addr6
 | 
	
		
			
				|  |  | +		bind.Addr = tun.AddressFromAddr(w.inet4Address)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
 | 
	
		
			
				|  |  |  	if err != nil {
 | 
	
	
		
			
				|  | @@ -130,6 +159,14 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
 | 
	
		
			
				|  |  |  	return udpConn, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func (w *stackDevice) Inet4Address() netip.Addr {
 | 
	
		
			
				|  |  | +	return w.inet4Address
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (w *stackDevice) Inet6Address() netip.Addr {
 | 
	
		
			
				|  |  | +	return w.inet6Address
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  func (w *stackDevice) SetDevice(device *device.Device) {
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -144,20 +181,24 @@ func (w *stackDevice) File() *os.File {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
 | 
	
		
			
				|  |  |  	select {
 | 
	
		
			
				|  |  | -	case packetBuffer, ok := <-w.outbound:
 | 
	
		
			
				|  |  | +	case packet, ok := <-w.outbound:
 | 
	
		
			
				|  |  |  		if !ok {
 | 
	
		
			
				|  |  |  			return 0, os.ErrClosed
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | -		defer packetBuffer.DecRef()
 | 
	
		
			
				|  |  | -		p := bufs[0]
 | 
	
		
			
				|  |  | -		p = p[offset:]
 | 
	
		
			
				|  |  | -		n := 0
 | 
	
		
			
				|  |  | -		for _, slice := range packetBuffer.AsSlices() {
 | 
	
		
			
				|  |  | -			n += copy(p[n:], slice)
 | 
	
		
			
				|  |  | +		defer packet.DecRef()
 | 
	
		
			
				|  |  | +		var copyN int
 | 
	
		
			
				|  |  | +		/*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
 | 
	
		
			
				|  |  | +			copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
 | 
	
		
			
				|  |  | +		})*/
 | 
	
		
			
				|  |  | +		for _, view := range packet.AsSlices() {
 | 
	
		
			
				|  |  | +			copyN += copy(bufs[0][offset+copyN:], view)
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | -		sizes[0] = n
 | 
	
		
			
				|  |  | -		count = 1
 | 
	
		
			
				|  |  | -		return
 | 
	
		
			
				|  |  | +		sizes[0] = copyN
 | 
	
		
			
				|  |  | +		return 1, nil
 | 
	
		
			
				|  |  | +	case packet := <-w.packetOutbound:
 | 
	
		
			
				|  |  | +		defer packet.Release()
 | 
	
		
			
				|  |  | +		sizes[0] = copy(bufs[0][offset:], packet.Bytes())
 | 
	
		
			
				|  |  | +		return 1, nil
 | 
	
		
			
				|  |  |  	case <-w.done:
 | 
	
		
			
				|  |  |  		return 0, os.ErrClosed
 | 
	
		
			
				|  |  |  	}
 | 
	
	
		
			
				|  | @@ -217,6 +258,23 @@ func (w *stackDevice) BatchSize() int {
 | 
	
		
			
				|  |  |  	return 1
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
 | 
	
		
			
				|  |  | +	ctx := log.ContextWithNewID(w.ctx)
 | 
	
		
			
				|  |  | +	destination, err := ping.ConnectGVisor(
 | 
	
		
			
				|  |  | +		ctx, w.logger,
 | 
	
		
			
				|  |  | +		metadata.Source.Addr, metadata.Destination.Addr,
 | 
	
		
			
				|  |  | +		routeContext,
 | 
	
		
			
				|  |  | +		w.stack,
 | 
	
		
			
				|  |  | +		w.inet4Address, w.inet6Address,
 | 
	
		
			
				|  |  | +		timeout,
 | 
	
		
			
				|  |  | +	)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		return nil, err
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
 | 
	
		
			
				|  |  | +	return destination, nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  type wireEndpoint stackDevice
 |