|
@@ -5,7 +5,9 @@ package wireguard
|
|
|
import (
|
|
|
"context"
|
|
|
"net"
|
|
|
+ "net/netip"
|
|
|
"os"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/sagernet/gvisor/pkg/buffer"
|
|
|
"github.com/sagernet/gvisor/pkg/tcpip"
|
|
@@ -14,9 +16,14 @@ import (
|
|
|
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
|
|
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
|
|
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
|
|
+ "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
|
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
|
|
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
|
|
+ "github.com/sagernet/sing-box/adapter"
|
|
|
+ "github.com/sagernet/sing-box/log"
|
|
|
"github.com/sagernet/sing-tun"
|
|
|
+ "github.com/sagernet/sing-tun/ping"
|
|
|
+ "github.com/sagernet/sing/common/buf"
|
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
@@ -24,30 +31,40 @@ import (
|
|
|
wgTun "github.com/sagernet/wireguard-go/tun"
|
|
|
)
|
|
|
|
|
|
-var _ Device = (*stackDevice)(nil)
|
|
|
+var _ NatDevice = (*stackDevice)(nil)
|
|
|
|
|
|
type stackDevice struct {
|
|
|
- stack *stack.Stack
|
|
|
- mtu uint32
|
|
|
- events chan wgTun.Event
|
|
|
- outbound chan *stack.PacketBuffer
|
|
|
- done chan struct{}
|
|
|
- dispatcher stack.NetworkDispatcher
|
|
|
- addr4 tcpip.Address
|
|
|
- addr6 tcpip.Address
|
|
|
+ ctx context.Context
|
|
|
+ logger log.ContextLogger
|
|
|
+ stack *stack.Stack
|
|
|
+ mtu uint32
|
|
|
+ events chan wgTun.Event
|
|
|
+ outbound chan *stack.PacketBuffer
|
|
|
+ packetOutbound chan *buf.Buffer
|
|
|
+ done chan struct{}
|
|
|
+ dispatcher stack.NetworkDispatcher
|
|
|
+ inet4Address netip.Addr
|
|
|
+ inet6Address netip.Addr
|
|
|
}
|
|
|
|
|
|
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
|
|
tunDevice := &stackDevice{
|
|
|
- mtu: options.MTU,
|
|
|
- events: make(chan wgTun.Event, 1),
|
|
|
- outbound: make(chan *stack.PacketBuffer, 256),
|
|
|
- done: make(chan struct{}),
|
|
|
+ ctx: options.Context,
|
|
|
+ logger: options.Logger,
|
|
|
+ mtu: options.MTU,
|
|
|
+ events: make(chan wgTun.Event, 1),
|
|
|
+ outbound: make(chan *stack.PacketBuffer, 256),
|
|
|
+ packetOutbound: make(chan *buf.Buffer, 256),
|
|
|
+ done: make(chan struct{}),
|
|
|
}
|
|
|
- ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
|
|
|
+ ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
+ var (
|
|
|
+ inet4Address netip.Addr
|
|
|
+ inet6Address netip.Addr
|
|
|
+ )
|
|
|
for _, prefix := range options.Address {
|
|
|
addr := tun.AddressFromAddr(prefix.Addr())
|
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
@@ -57,10 +74,12 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
|
|
},
|
|
|
}
|
|
|
if prefix.Addr().Is4() {
|
|
|
- tunDevice.addr4 = addr
|
|
|
+ inet4Address = prefix.Addr()
|
|
|
+ tunDevice.inet4Address = inet4Address
|
|
|
protoAddr.Protocol = ipv4.ProtocolNumber
|
|
|
} else {
|
|
|
- tunDevice.addr6 = addr
|
|
|
+ inet6Address = prefix.Addr()
|
|
|
+ tunDevice.inet6Address = inet6Address
|
|
|
protoAddr.Protocol = ipv6.ProtocolNumber
|
|
|
}
|
|
|
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
|
|
@@ -72,6 +91,10 @@ func newStackDevice(options DeviceOptions) (*stackDevice, error) {
|
|
|
if options.Handler != nil {
|
|
|
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
|
|
|
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
|
|
|
+ icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout)
|
|
|
+ icmpForwarder.SetLocalAddresses(inet4Address, inet6Address)
|
|
|
+ ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
|
|
|
+ ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
|
|
|
}
|
|
|
return tunDevice, nil
|
|
|
}
|
|
@@ -87,11 +110,17 @@ func (w *stackDevice) DialContext(ctx context.Context, network string, destinati
|
|
|
}
|
|
|
var networkProtocol tcpip.NetworkProtocolNumber
|
|
|
if destination.IsIPv4() {
|
|
|
+ if !w.inet4Address.IsValid() {
|
|
|
+ return nil, E.New("missing IPv4 local address")
|
|
|
+ }
|
|
|
networkProtocol = header.IPv4ProtocolNumber
|
|
|
- bind.Addr = w.addr4
|
|
|
+ bind.Addr = tun.AddressFromAddr(w.inet4Address)
|
|
|
} else {
|
|
|
+ if !w.inet6Address.IsValid() {
|
|
|
+ return nil, E.New("missing IPv6 local address")
|
|
|
+ }
|
|
|
networkProtocol = header.IPv6ProtocolNumber
|
|
|
- bind.Addr = w.addr6
|
|
|
+ bind.Addr = tun.AddressFromAddr(w.inet6Address)
|
|
|
}
|
|
|
switch N.NetworkName(network) {
|
|
|
case N.NetworkTCP:
|
|
@@ -118,10 +147,10 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
|
|
var networkProtocol tcpip.NetworkProtocolNumber
|
|
|
if destination.IsIPv4() {
|
|
|
networkProtocol = header.IPv4ProtocolNumber
|
|
|
- bind.Addr = w.addr4
|
|
|
+ bind.Addr = tun.AddressFromAddr(w.inet4Address)
|
|
|
} else {
|
|
|
networkProtocol = header.IPv6ProtocolNumber
|
|
|
- bind.Addr = w.addr6
|
|
|
+ bind.Addr = tun.AddressFromAddr(w.inet4Address)
|
|
|
}
|
|
|
udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
|
|
|
if err != nil {
|
|
@@ -130,6 +159,14 @@ func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
|
|
|
return udpConn, nil
|
|
|
}
|
|
|
|
|
|
+func (w *stackDevice) Inet4Address() netip.Addr {
|
|
|
+ return w.inet4Address
|
|
|
+}
|
|
|
+
|
|
|
+func (w *stackDevice) Inet6Address() netip.Addr {
|
|
|
+ return w.inet6Address
|
|
|
+}
|
|
|
+
|
|
|
func (w *stackDevice) SetDevice(device *device.Device) {
|
|
|
}
|
|
|
|
|
@@ -144,20 +181,24 @@ func (w *stackDevice) File() *os.File {
|
|
|
|
|
|
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
|
|
|
select {
|
|
|
- case packetBuffer, ok := <-w.outbound:
|
|
|
+ case packet, ok := <-w.outbound:
|
|
|
if !ok {
|
|
|
return 0, os.ErrClosed
|
|
|
}
|
|
|
- defer packetBuffer.DecRef()
|
|
|
- p := bufs[0]
|
|
|
- p = p[offset:]
|
|
|
- n := 0
|
|
|
- for _, slice := range packetBuffer.AsSlices() {
|
|
|
- n += copy(p[n:], slice)
|
|
|
+ defer packet.DecRef()
|
|
|
+ var copyN int
|
|
|
+ /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) {
|
|
|
+ copyN += copy(bufs[0][offset+copyN:], view.AsSlice())
|
|
|
+ })*/
|
|
|
+ for _, view := range packet.AsSlices() {
|
|
|
+ copyN += copy(bufs[0][offset+copyN:], view)
|
|
|
}
|
|
|
- sizes[0] = n
|
|
|
- count = 1
|
|
|
- return
|
|
|
+ sizes[0] = copyN
|
|
|
+ return 1, nil
|
|
|
+ case packet := <-w.packetOutbound:
|
|
|
+ defer packet.Release()
|
|
|
+ sizes[0] = copy(bufs[0][offset:], packet.Bytes())
|
|
|
+ return 1, nil
|
|
|
case <-w.done:
|
|
|
return 0, os.ErrClosed
|
|
|
}
|
|
@@ -217,6 +258,23 @@ func (w *stackDevice) BatchSize() int {
|
|
|
return 1
|
|
|
}
|
|
|
|
|
|
+func (w *stackDevice) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
|
|
|
+ ctx := log.ContextWithNewID(w.ctx)
|
|
|
+ destination, err := ping.ConnectGVisor(
|
|
|
+ ctx, w.logger,
|
|
|
+ metadata.Source.Addr, metadata.Destination.Addr,
|
|
|
+ routeContext,
|
|
|
+ w.stack,
|
|
|
+ w.inet4Address, w.inet6Address,
|
|
|
+ timeout,
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
|
|
|
+ return destination, nil
|
|
|
+}
|
|
|
+
|
|
|
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
|
|
|
|
|
|
type wireEndpoint stackDevice
|