浏览代码

Refactor wireguard & add tun support

世界 3 年之前
父节点
当前提交
cb4fea0240

+ 1 - 2
docs/configuration/outbound/wireguard.md

@@ -8,7 +8,6 @@
   "server": "127.0.0.1",
   "server_port": 1080,
   "local_address": [
-    "10.0.0.1",
     "10.0.0.2/32"
   ],
   "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
@@ -43,7 +42,7 @@ The server port.
 
 ==Required==
 
-List of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface.
+List of IP (v4 or v6) address prefixes to be assigned to the interface.
 
 #### private_key
 

+ 1 - 2
docs/configuration/outbound/wireguard.zh.md

@@ -8,7 +8,6 @@
   "server": "127.0.0.1",
   "server_port": 1080,
   "local_address": [
-    "10.0.0.1",
     "10.0.0.2/32"
   ],
   "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
@@ -45,7 +44,7 @@
 
 接口的 IPv4/IPv6 地址或地址段的列表您。
 
-要分配给接口的 IP(v4 或 v6)地址列表(可以选择带有 CIDR 掩码)
+要分配给接口的 IP(v4 或 v6)地址列表。
 
 #### private_key
 

+ 4 - 3
inbound/tun.go

@@ -40,7 +40,7 @@ type Tun struct {
 func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunInboundOptions) (*Tun, error) {
 	tunName := options.InterfaceName
 	if tunName == "" {
-		tunName = tun.DefaultInterfaceName()
+		tunName = tun.CalculateInterfaceName("")
 	}
 	tunMTU := options.MTU
 	if tunMTU == 0 {
@@ -77,8 +77,8 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger
 		tunOptions: tun.Options{
 			Name:               tunName,
 			MTU:                tunMTU,
-			Inet4Address:       options.Inet4Address.Build(),
-			Inet6Address:       options.Inet6Address.Build(),
+			Inet4Address:       common.Map(options.Inet4Address, option.ListenPrefix.Build),
+			Inet6Address:       common.Map(options.Inet6Address, option.ListenPrefix.Build),
 			AutoRoute:          options.AutoRoute,
 			StrictRoute:        options.StrictRoute,
 			IncludeUID:         includeUID,
@@ -87,6 +87,7 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger
 			IncludePackage:     options.IncludePackage,
 			ExcludePackage:     options.ExcludePackage,
 			InterfaceMonitor:   router.InterfaceMonitor(),
+			TableIndex:         2022,
 		},
 		endpointIndependentNat: options.EndpointIndependentNat,
 		udpTimeout:             udpTimeout,

+ 1 - 0
option/outbound.go

@@ -108,6 +108,7 @@ type DialerOptions struct {
 	Detour         string         `json:"detour,omitempty"`
 	BindInterface  string         `json:"bind_interface,omitempty"`
 	BindAddress    *ListenAddress `json:"bind_address,omitempty"`
+	BindAddress6   *ListenAddress `json:"bind_address6,omitempty"`
 	ProtectPath    string         `json:"protect_path,omitempty"`
 	RoutingMark    int            `json:"routing_mark,omitempty"`
 	ReuseAddr      bool           `json:"reuse_addr,omitempty"`

+ 16 - 16
option/tun.go

@@ -1,21 +1,21 @@
 package option
 
 type TunInboundOptions struct {
-	InterfaceName          string           `json:"interface_name,omitempty"`
-	MTU                    uint32           `json:"mtu,omitempty"`
-	Inet4Address           *ListenPrefix    `json:"inet4_address,omitempty"`
-	Inet6Address           *ListenPrefix    `json:"inet6_address,omitempty"`
-	AutoRoute              bool             `json:"auto_route,omitempty"`
-	StrictRoute            bool             `json:"strict_route,omitempty"`
-	IncludeUID             Listable[uint32] `json:"include_uid,omitempty"`
-	IncludeUIDRange        Listable[string] `json:"include_uid_range,omitempty"`
-	ExcludeUID             Listable[uint32] `json:"exclude_uid,omitempty"`
-	ExcludeUIDRange        Listable[string] `json:"exclude_uid_range,omitempty"`
-	IncludeAndroidUser     Listable[int]    `json:"include_android_user,omitempty"`
-	IncludePackage         Listable[string] `json:"include_package,omitempty"`
-	ExcludePackage         Listable[string] `json:"exclude_package,omitempty"`
-	EndpointIndependentNat bool             `json:"endpoint_independent_nat,omitempty"`
-	UDPTimeout             int64            `json:"udp_timeout,omitempty"`
-	Stack                  string           `json:"stack,omitempty"`
+	InterfaceName          string                 `json:"interface_name,omitempty"`
+	MTU                    uint32                 `json:"mtu,omitempty"`
+	Inet4Address           Listable[ListenPrefix] `json:"inet4_address,omitempty"`
+	Inet6Address           Listable[ListenPrefix] `json:"inet6_address,omitempty"`
+	AutoRoute              bool                   `json:"auto_route,omitempty"`
+	StrictRoute            bool                   `json:"strict_route,omitempty"`
+	IncludeUID             Listable[uint32]       `json:"include_uid,omitempty"`
+	IncludeUIDRange        Listable[string]       `json:"include_uid_range,omitempty"`
+	ExcludeUID             Listable[uint32]       `json:"exclude_uid,omitempty"`
+	ExcludeUIDRange        Listable[string]       `json:"exclude_uid_range,omitempty"`
+	IncludeAndroidUser     Listable[int]          `json:"include_android_user,omitempty"`
+	IncludePackage         Listable[string]       `json:"include_package,omitempty"`
+	ExcludePackage         Listable[string]       `json:"exclude_package,omitempty"`
+	EndpointIndependentNat bool                   `json:"endpoint_independent_nat,omitempty"`
+	UDPTimeout             int64                  `json:"udp_timeout,omitempty"`
+	Stack                  string                 `json:"stack,omitempty"`
 	InboundOptions
 }

+ 2 - 5
option/types.go

@@ -184,9 +184,6 @@ func (p *ListenPrefix) UnmarshalJSON(bytes []byte) error {
 	return nil
 }
 
-func (p *ListenPrefix) Build() netip.Prefix {
-	if p == nil {
-		return netip.Prefix{}
-	}
-	return netip.Prefix(*p)
+func (p ListenPrefix) Build() netip.Prefix {
+	return netip.Prefix(p)
 }

+ 8 - 6
option/wireguard.go

@@ -3,10 +3,12 @@ package option
 type WireGuardOutboundOptions struct {
 	DialerOptions
 	ServerOptions
-	LocalAddress  Listable[string] `json:"local_address"`
-	PrivateKey    string           `json:"private_key"`
-	PeerPublicKey string           `json:"peer_public_key"`
-	PreSharedKey  string           `json:"pre_shared_key,omitempty"`
-	MTU           uint32           `json:"mtu,omitempty"`
-	Network       NetworkList      `json:"network,omitempty"`
+	SystemInterface bool                   `json:"system_interface,omitempty"`
+	InterfaceName   string                 `json:"interface_name,omitempty"`
+	LocalAddress    Listable[ListenPrefix] `json:"local_address"`
+	PrivateKey      string                 `json:"private_key"`
+	PeerPublicKey   string                 `json:"peer_public_key"`
+	PreSharedKey    string                 `json:"pre_shared_key,omitempty"`
+	MTU             uint32                 `json:"mtu,omitempty"`
+	Network         NetworkList            `json:"network,omitempty"`
 }

+ 27 - 388
outbound/wireguard.go

@@ -8,49 +8,30 @@ import (
 	"encoding/hex"
 	"fmt"
 	"net"
-	"net/netip"
-	"os"
 	"strings"
-	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-box/transport/wireguard"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/debug"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
-	"golang.zx2c4.com/wireguard/conn"
 	"golang.zx2c4.com/wireguard/device"
-	"golang.zx2c4.com/wireguard/tun"
-	"gvisor.dev/gvisor/pkg/bufferv2"
-	"gvisor.dev/gvisor/pkg/tcpip"
-	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
-	"gvisor.dev/gvisor/pkg/tcpip/header"
-	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
-	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
-	"gvisor.dev/gvisor/pkg/tcpip/stack"
-	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
-	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
 )
 
 var _ adapter.Outbound = (*WireGuard)(nil)
 
 type WireGuard struct {
 	myOutboundAdapter
-	ctx        context.Context
-	serverAddr M.Socksaddr
-	dialer     N.Dialer
-	endpoint   conn.Endpoint
-	device     *device.Device
-	tunDevice  *wireTunDevice
-	connAccess sync.Mutex
-	conn       *wireConn
+	bind      *wireguard.ClientBind
+	device    *device.Device
+	tunDevice wireguard.Device
 }
 
 func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) {
@@ -62,39 +43,13 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 			logger:   logger,
 			tag:      tag,
 		},
-		ctx:        ctx,
-		serverAddr: options.ServerOptions.Build(),
-		dialer:     dialer.New(router, options.DialerOptions),
 	}
-	var endpointIp netip.Addr
-	if !outbound.serverAddr.IsFqdn() {
-		endpointIp = outbound.serverAddr.Addr
-	} else {
-		endpointIp = netip.AddrFrom4([4]byte{127, 0, 0, 1})
-	}
-	outbound.endpoint = conn.StdNetEndpoint(netip.AddrPortFrom(endpointIp, outbound.serverAddr.Port))
-	localAddress := make([]tcpip.AddressWithPrefix, len(options.LocalAddress))
-	if len(localAddress) == 0 {
+	peerAddr := options.ServerOptions.Build()
+	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr)
+	localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
+	if len(localPrefixes) == 0 {
 		return nil, E.New("missing local address")
 	}
-	for index, address := range options.LocalAddress {
-		if strings.Contains(address, "/") {
-			prefix, err := netip.ParsePrefix(address)
-			if err != nil {
-				return nil, E.Cause(err, "parse local address prefix ", address)
-			}
-			localAddress[index] = tcpip.AddressWithPrefix{
-				Address:   tcpip.Address(prefix.Addr().AsSlice()),
-				PrefixLen: prefix.Bits(),
-			}
-		} else {
-			addr, err := netip.ParseAddr(address)
-			if err != nil {
-				return nil, E.Cause(err, "parse local address ", address)
-			}
-			localAddress[index] = tcpip.Address(addr.AsSlice()).WithPrefix()
-		}
-	}
 	var privateKey, peerPublicKey, preSharedKey string
 	{
 		bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
@@ -119,13 +74,13 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 	}
 	ipcConf := "private_key=" + privateKey
 	ipcConf += "\npublic_key=" + peerPublicKey
-	ipcConf += "\nendpoint=" + outbound.endpoint.DstToString()
+	ipcConf += "\nendpoint=" + peerAddr.String()
 	if preSharedKey != "" {
 		ipcConf += "\npreshared_key=" + preSharedKey
 	}
 	var has4, has6 bool
-	for _, address := range localAddress {
-		if address.Address.To4() != "" {
+	for _, address := range localPrefixes {
+		if address.Addr().Is4() {
 			has4 = true
 		} else {
 			has6 = true
@@ -141,11 +96,17 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 	if mtu == 0 {
 		mtu = 1408
 	}
-	wireDevice, err := newWireDevice(localAddress, mtu)
+	var wireTunDevice wireguard.Device
+	var err error
+	if !options.SystemInterface {
+		wireTunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu)
+	} else {
+		wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu)
+	}
 	if err != nil {
-		return nil, err
+		return nil, E.Cause(err, "create WireGuard device")
 	}
-	wgDevice := device.NewDevice(wireDevice, (*wireClientBind)(outbound), &device.Logger{
+	wgDevice := device.NewDevice(wireTunDevice, outbound.bind, &device.Logger{
 		Verbosef: func(format string, args ...interface{}) {
 			logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
 		},
@@ -161,7 +122,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 		return nil, E.Cause(err, "setup wireguard")
 	}
 	outbound.device = wgDevice
-	outbound.tunDevice = wireDevice
+	outbound.tunDevice = wireTunDevice
 	return outbound, nil
 }
 
@@ -172,54 +133,19 @@ func (w *WireGuard) DialContext(ctx context.Context, network string, destination
 	case N.NetworkUDP:
 		w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
 	}
-	addr := tcpip.FullAddress{
-		NIC:  defaultNIC,
-		Port: destination.Port,
-	}
 	if destination.IsFqdn() {
 		addrs, err := w.router.LookupDefault(ctx, destination.Fqdn)
 		if err != nil {
 			return nil, err
 		}
-		addr.Addr = tcpip.Address(addrs[0].AsSlice())
-	} else {
-		addr.Addr = tcpip.Address(destination.Addr.AsSlice())
-	}
-	bind := tcpip.FullAddress{
-		NIC: defaultNIC,
-	}
-	var networkProtocol tcpip.NetworkProtocolNumber
-	if destination.IsIPv4() {
-		networkProtocol = header.IPv4ProtocolNumber
-		bind.Addr = w.tunDevice.addr4
-	} else {
-		networkProtocol = header.IPv6ProtocolNumber
-		bind.Addr = w.tunDevice.addr6
-	}
-	switch N.NetworkName(network) {
-	case N.NetworkTCP:
-		return gonet.DialTCPWithBind(ctx, w.tunDevice.stack, bind, addr, networkProtocol)
-	case N.NetworkUDP:
-		return gonet.DialUDP(w.tunDevice.stack, &bind, &addr, networkProtocol)
-	default:
-		return nil, E.Extend(N.ErrUnknownNetwork, network)
+		return N.DialSerial(ctx, w.tunDevice, network, destination, addrs)
 	}
+	return w.tunDevice.DialContext(ctx, network, destination)
 }
 
 func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
 	w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
-	bind := tcpip.FullAddress{
-		NIC: defaultNIC,
-	}
-	var networkProtocol tcpip.NetworkProtocolNumber
-	if destination.IsIPv4() || w.tunDevice.addr6 == "" {
-		networkProtocol = header.IPv4ProtocolNumber
-		bind.Addr = w.tunDevice.addr4
-	} else {
-		networkProtocol = header.IPv6ProtocolNumber
-		bind.Addr = w.tunDevice.addr6
-	}
-	return gonet.DialUDP(w.tunDevice.stack, &bind, nil, networkProtocol)
+	return w.tunDevice.ListenPacket(ctx, destination)
 }
 
 func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
@@ -231,300 +157,13 @@ func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn,
 }
 
 func (w *WireGuard) Start() error {
-	w.tunDevice.events <- tun.EventUp
-	return nil
+	return w.tunDevice.Start()
 }
 
 func (w *WireGuard) Close() error {
 	return common.Close(
-		common.PtrOrNil(w.tunDevice),
+		w.tunDevice,
 		common.PtrOrNil(w.device),
-		common.PtrOrNil(w.conn),
+		common.PtrOrNil(w.bind),
 	)
 }
-
-var _ conn.Bind = (*wireClientBind)(nil)
-
-type wireClientBind WireGuard
-
-func (c *wireClientBind) connect() (*wireConn, error) {
-	c.connAccess.Lock()
-	defer c.connAccess.Unlock()
-	if c.conn != nil {
-		select {
-		case <-c.conn.done:
-		default:
-			return c.conn, nil
-		}
-	}
-	udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
-	if err != nil {
-		return nil, &wireError{err}
-	}
-	c.conn = &wireConn{
-		Conn: udpConn,
-		done: make(chan struct{}),
-	}
-	return c.conn, nil
-}
-
-func (c *wireClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
-	return []conn.ReceiveFunc{c.receive}, 0, nil
-}
-
-func (c *wireClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
-	udpConn, err := c.connect()
-	if err != nil {
-		err = &wireError{err}
-		return
-	}
-	n, err = udpConn.Read(b)
-	if err != nil {
-		udpConn.Close()
-		err = &wireError{err}
-	}
-	ep = c.endpoint
-	return
-}
-
-func (c *wireClientBind) Close() error {
-	c.connAccess.Lock()
-	defer c.connAccess.Unlock()
-	common.Close(common.PtrOrNil(c.conn))
-	return nil
-}
-
-func (c *wireClientBind) SetMark(mark uint32) error {
-	return nil
-}
-
-func (c *wireClientBind) Send(b []byte, ep conn.Endpoint) error {
-	udpConn, err := c.connect()
-	if err != nil {
-		return err
-	}
-	_, err = udpConn.Write(b)
-	if err != nil {
-		udpConn.Close()
-	}
-	return err
-}
-
-func (c *wireClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
-	return c.endpoint, nil
-}
-
-type wireError struct {
-	cause error
-}
-
-func (w *wireError) Error() string {
-	return w.cause.Error()
-}
-
-func (w *wireError) Timeout() bool {
-	if cause, causeNet := w.cause.(net.Error); causeNet {
-		return cause.Timeout()
-	}
-	return false
-}
-
-func (w *wireError) Temporary() bool {
-	return true
-}
-
-type wireConn struct {
-	net.Conn
-	access sync.Mutex
-	done   chan struct{}
-}
-
-func (w *wireConn) Close() error {
-	w.access.Lock()
-	defer w.access.Unlock()
-	select {
-	case <-w.done:
-		return net.ErrClosed
-	default:
-	}
-	w.Conn.Close()
-	close(w.done)
-	return nil
-}
-
-var _ tun.Device = (*wireTunDevice)(nil)
-
-const defaultNIC tcpip.NICID = 1
-
-type wireTunDevice struct {
-	stack      *stack.Stack
-	mtu        uint32
-	events     chan tun.Event
-	outbound   chan *stack.PacketBuffer
-	dispatcher stack.NetworkDispatcher
-	done       chan struct{}
-	addr4      tcpip.Address
-	addr6      tcpip.Address
-}
-
-func newWireDevice(localAddresses []tcpip.AddressWithPrefix, mtu uint32) (*wireTunDevice, error) {
-	ipStack := stack.New(stack.Options{
-		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
-		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
-		HandleLocal:        true,
-	})
-	tunDevice := &wireTunDevice{
-		stack:    ipStack,
-		mtu:      mtu,
-		events:   make(chan tun.Event, 4),
-		outbound: make(chan *stack.PacketBuffer, 256),
-		done:     make(chan struct{}),
-	}
-	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
-	if err != nil {
-		return nil, E.New(err.String())
-	}
-	for _, addr := range localAddresses {
-		var protoAddr tcpip.ProtocolAddress
-		if len(addr.Address) == net.IPv4len {
-			tunDevice.addr4 = addr.Address
-			protoAddr = tcpip.ProtocolAddress{
-				Protocol:          ipv4.ProtocolNumber,
-				AddressWithPrefix: addr,
-			}
-		} else {
-			tunDevice.addr6 = addr.Address
-			protoAddr = tcpip.ProtocolAddress{
-				Protocol:          ipv6.ProtocolNumber,
-				AddressWithPrefix: addr,
-			}
-		}
-		err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
-		if err != nil {
-			return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
-		}
-	}
-	sOpt := tcpip.TCPSACKEnabled(true)
-	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
-	cOpt := tcpip.CongestionControlOption("cubic")
-	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
-	ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
-	ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
-	return tunDevice, nil
-}
-
-func (w *wireTunDevice) File() *os.File {
-	return nil
-}
-
-func (w *wireTunDevice) Read(p []byte, offset int) (n int, err error) {
-	packetBuffer, ok := <-w.outbound
-	if !ok {
-		return 0, os.ErrClosed
-	}
-	defer packetBuffer.DecRef()
-	p = p[offset:]
-	for _, slice := range packetBuffer.AsSlices() {
-		n += copy(p[n:], slice)
-	}
-	return
-}
-
-func (w *wireTunDevice) Write(p []byte, offset int) (n int, err error) {
-	p = p[offset:]
-	if len(p) == 0 {
-		return
-	}
-	var networkProtocol tcpip.NetworkProtocolNumber
-	switch header.IPVersion(p) {
-	case header.IPv4Version:
-		networkProtocol = header.IPv4ProtocolNumber
-	case header.IPv6Version:
-		networkProtocol = header.IPv6ProtocolNumber
-	}
-	packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		Payload: bufferv2.MakeWithData(p),
-	})
-	defer packetBuffer.DecRef()
-	w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
-	n = len(p)
-	return
-}
-
-func (w *wireTunDevice) Flush() error {
-	return nil
-}
-
-func (w *wireTunDevice) MTU() (int, error) {
-	return int(w.mtu), nil
-}
-
-func (w *wireTunDevice) Name() (string, error) {
-	return "sing-box", nil
-}
-
-func (w *wireTunDevice) Events() chan tun.Event {
-	return w.events
-}
-
-func (w *wireTunDevice) Close() error {
-	select {
-	case <-w.done:
-		return os.ErrClosed
-	default:
-	}
-	close(w.done)
-	w.stack.Close()
-	for _, endpoint := range w.stack.CleanupEndpoints() {
-		endpoint.Abort()
-	}
-	w.stack.Wait()
-	close(w.outbound)
-	return nil
-}
-
-var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
-
-type wireEndpoint wireTunDevice
-
-func (ep *wireEndpoint) MTU() uint32 {
-	return ep.mtu
-}
-
-func (ep *wireEndpoint) MaxHeaderLength() uint16 {
-	return 0
-}
-
-func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
-	return ""
-}
-
-func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
-	return stack.CapabilityNone
-}
-
-func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
-	ep.dispatcher = dispatcher
-}
-
-func (ep *wireEndpoint) IsAttached() bool {
-	return ep.dispatcher != nil
-}
-
-func (ep *wireEndpoint) Wait() {
-}
-
-func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
-	return header.ARPHardwareNone
-}
-
-func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
-}
-
-func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
-	for _, packetBuffer := range list.AsSlice() {
-		packetBuffer.IncRef()
-		ep.outbound <- packetBuffer
-	}
-	return list.Len(), nil
-}

+ 1 - 1
test/config/wireguard.conf

@@ -1,6 +1,6 @@
 [Interface]
 PrivateKey = gHWUGzTh5YCEV6k8dneVP537XhVtoQJPIlFNs2zsxlE=
-Address = 10.0.0.1/24
+Address = 10.0.0.1/32
 ListenPort = 10000
 
 [Peer]

+ 48 - 1
test/wireguard_test.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"net/netip"
+	"os"
 	"testing"
 	"time"
 
@@ -40,7 +41,7 @@ func TestWireGuard(t *testing.T) {
 						Server:     "127.0.0.1",
 						ServerPort: serverPort,
 					},
-					LocalAddress:  []string{"10.0.0.2/32"},
+					LocalAddress:  []option.ListenPrefix{option.ListenPrefix(netip.MustParsePrefix("10.0.0.2/32"))},
 					PrivateKey:    "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
 					PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
 				},
@@ -49,3 +50,49 @@ func TestWireGuard(t *testing.T) {
 	})
 	testSuitWg(t, clientPort, testPort)
 }
+
+func TestWireGuardSystem(t *testing.T) {
+	if os.Getuid() != 0 {
+		t.Skip("requires root")
+	}
+	startDockerContainer(t, DockerOptions{
+		Image: ImageBoringTun,
+		Cap:   []string{"MKNOD", "NET_ADMIN", "NET_RAW"},
+		Ports: []uint16{serverPort, testPort},
+		Bind: map[string]string{
+			"wireguard.conf": "/etc/wireguard/wg0.conf",
+		},
+		Cmd: []string{"wg0"},
+	})
+	time.Sleep(5 * time.Second)
+	startInstance(t, option.Options{
+		Inbounds: []option.Inbound{
+			{
+				Type: C.TypeMixed,
+				MixedOptions: option.HTTPMixedInboundOptions{
+					ListenOptions: option.ListenOptions{
+						Listen:     option.ListenAddress(netip.IPv4Unspecified()),
+						ListenPort: clientPort,
+					},
+				},
+			},
+		},
+		Outbounds: []option.Outbound{
+			{
+				Type: C.TypeWireGuard,
+				WireGuardOptions: option.WireGuardOutboundOptions{
+					InterfaceName: "wg",
+					ServerOptions: option.ServerOptions{
+						Server:     "127.0.0.1",
+						ServerPort: serverPort,
+					},
+					LocalAddress:  []option.ListenPrefix{option.ListenPrefix(netip.MustParsePrefix("10.0.0.2/32"))},
+					PrivateKey:    "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=",
+					PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=",
+				},
+			},
+		},
+	})
+	time.Sleep(10 * time.Second)
+	testSuitWg(t, clientPort, testPort)
+}

+ 132 - 0
transport/wireguard/client_bind.go

@@ -0,0 +1,132 @@
+package wireguard
+
+import (
+	"context"
+	"net"
+	"sync"
+
+	"github.com/sagernet/sing/common"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"golang.zx2c4.com/wireguard/conn"
+)
+
+var _ conn.Bind = (*ClientBind)(nil)
+
+type ClientBind struct {
+	ctx        context.Context
+	dialer     N.Dialer
+	peerAddr   M.Socksaddr
+	connAccess sync.Mutex
+	conn       *wireConn
+}
+
+func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr) *ClientBind {
+	return &ClientBind{
+		ctx:      ctx,
+		dialer:   dialer,
+		peerAddr: peerAddr,
+	}
+}
+
+func (c *ClientBind) connect() (*wireConn, error) {
+	serverConn := c.conn
+	if serverConn != nil {
+		select {
+		case <-serverConn.done:
+			serverConn = nil
+		default:
+			return serverConn, nil
+		}
+	}
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	serverConn = c.conn
+	if serverConn != nil {
+		select {
+		case <-serverConn.done:
+			serverConn = nil
+		default:
+			return serverConn, nil
+		}
+	}
+	udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
+	if err != nil {
+		return nil, &wireError{err}
+	}
+	c.conn = &wireConn{
+		Conn: udpConn,
+		done: make(chan struct{}),
+	}
+	return c.conn, nil
+}
+
+func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+	return []conn.ReceiveFunc{c.receive}, 0, nil
+}
+
+func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
+	udpConn, err := c.connect()
+	if err != nil {
+		err = &wireError{err}
+		return
+	}
+	n, err = udpConn.Read(b)
+	if err != nil {
+		udpConn.Close()
+		err = &wireError{err}
+	}
+	ep = Endpoint(c.peerAddr)
+	return
+}
+
+func (c *ClientBind) Close() error {
+	c.connAccess.Lock()
+	defer c.connAccess.Unlock()
+	common.Close(common.PtrOrNil(c.conn))
+	return nil
+}
+
+func (c *ClientBind) SetMark(mark uint32) error {
+	return nil
+}
+
+func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
+	udpConn, err := c.connect()
+	if err != nil {
+		return err
+	}
+	_, err = udpConn.Write(b)
+	if err != nil {
+		udpConn.Close()
+	}
+	return err
+}
+
+func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+	return Endpoint(c.peerAddr), nil
+}
+
+func (c *ClientBind) Endpoint() conn.Endpoint {
+	return Endpoint(c.peerAddr)
+}
+
+type wireConn struct {
+	net.Conn
+	access sync.Mutex
+	done   chan struct{}
+}
+
+func (w *wireConn) Close() error {
+	w.access.Lock()
+	defer w.access.Unlock()
+	select {
+	case <-w.done:
+		return net.ErrClosed
+	default:
+	}
+	w.Conn.Close()
+	close(w.done)
+	return nil
+}

+ 14 - 0
transport/wireguard/device.go

@@ -0,0 +1,14 @@
+package wireguard
+
+import (
+	N "github.com/sagernet/sing/common/network"
+
+	"golang.zx2c4.com/wireguard/tun"
+)
+
+type Device interface {
+	tun.Device
+	N.Dialer
+	Start() error
+	// NewEndpoint() (stack.LinkEndpoint, error)
+}

+ 254 - 0
transport/wireguard/device_stack.go

@@ -0,0 +1,254 @@
+//go:build !no_gvisor
+
+package wireguard
+
+import (
+	"context"
+	"net"
+	"net/netip"
+	"os"
+
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"golang.zx2c4.com/wireguard/tun"
+	"gvisor.dev/gvisor/pkg/bufferv2"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+var _ Device = (*StackDevice)(nil)
+
+const defaultNIC tcpip.NICID = 1
+
+type StackDevice struct {
+	stack      *stack.Stack
+	mtu        uint32
+	events     chan tun.Event
+	outbound   chan *stack.PacketBuffer
+	dispatcher stack.NetworkDispatcher
+	done       chan struct{}
+	addr4      tcpip.Address
+	addr6      tcpip.Address
+}
+
+func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
+	ipStack := stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
+		HandleLocal:        true,
+	})
+	tunDevice := &StackDevice{
+		stack:    ipStack,
+		mtu:      mtu,
+		events:   make(chan tun.Event, 1),
+		outbound: make(chan *stack.PacketBuffer, 256),
+		done:     make(chan struct{}),
+	}
+	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
+	if err != nil {
+		return nil, E.New(err.String())
+	}
+	for _, prefix := range localAddresses {
+		addr := tcpip.Address(prefix.Addr().AsSlice())
+		protoAddr := tcpip.ProtocolAddress{
+			AddressWithPrefix: tcpip.AddressWithPrefix{
+				Address:   addr,
+				PrefixLen: prefix.Bits(),
+			},
+		}
+		if prefix.Addr().Is4() {
+			tunDevice.addr4 = addr
+			protoAddr.Protocol = ipv4.ProtocolNumber
+		} else {
+			tunDevice.addr6 = addr
+			protoAddr.Protocol = ipv6.ProtocolNumber
+		}
+		err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
+		if err != nil {
+			return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
+		}
+	}
+	sOpt := tcpip.TCPSACKEnabled(true)
+	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
+	cOpt := tcpip.CongestionControlOption("cubic")
+	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
+	ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
+	ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
+	return tunDevice, nil
+}
+
+func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
+	return (*wireEndpoint)(w), nil
+}
+
+func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	addr := tcpip.FullAddress{
+		NIC:  defaultNIC,
+		Port: destination.Port,
+		Addr: tcpip.Address(destination.Addr.AsSlice()),
+	}
+	bind := tcpip.FullAddress{
+		NIC: defaultNIC,
+	}
+	var networkProtocol tcpip.NetworkProtocolNumber
+	if destination.IsIPv4() {
+		networkProtocol = header.IPv4ProtocolNumber
+		bind.Addr = w.addr4
+	} else {
+		networkProtocol = header.IPv6ProtocolNumber
+		bind.Addr = w.addr6
+	}
+	switch N.NetworkName(network) {
+	case N.NetworkTCP:
+		return gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
+	case N.NetworkUDP:
+		return gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
+	default:
+		return nil, E.Extend(N.ErrUnknownNetwork, network)
+	}
+}
+
+func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	bind := tcpip.FullAddress{
+		NIC: defaultNIC,
+	}
+	var networkProtocol tcpip.NetworkProtocolNumber
+	if destination.IsIPv4() || w.addr6 == "" {
+		networkProtocol = header.IPv4ProtocolNumber
+		bind.Addr = w.addr4
+	} else {
+		networkProtocol = header.IPv6ProtocolNumber
+		bind.Addr = w.addr6
+	}
+	return gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
+}
+
+func (w *StackDevice) Start() error {
+	w.events <- tun.EventUp
+	return nil
+}
+
+func (w *StackDevice) File() *os.File {
+	return nil
+}
+
+func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
+	packetBuffer, ok := <-w.outbound
+	if !ok {
+		return 0, os.ErrClosed
+	}
+	defer packetBuffer.DecRef()
+	p = p[offset:]
+	for _, slice := range packetBuffer.AsSlices() {
+		n += copy(p[n:], slice)
+	}
+	return
+}
+
+func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
+	p = p[offset:]
+	if len(p) == 0 {
+		return
+	}
+	var networkProtocol tcpip.NetworkProtocolNumber
+	switch header.IPVersion(p) {
+	case header.IPv4Version:
+		networkProtocol = header.IPv4ProtocolNumber
+	case header.IPv6Version:
+		networkProtocol = header.IPv6ProtocolNumber
+	}
+	packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		Payload: bufferv2.MakeWithData(p),
+	})
+	defer packetBuffer.DecRef()
+	w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
+	n = len(p)
+	return
+}
+
+func (w *StackDevice) Flush() error {
+	return nil
+}
+
+func (w *StackDevice) MTU() (int, error) {
+	return int(w.mtu), nil
+}
+
+func (w *StackDevice) Name() (string, error) {
+	return "sing-box", nil
+}
+
+func (w *StackDevice) Events() chan tun.Event {
+	return w.events
+}
+
+func (w *StackDevice) Close() error {
+	select {
+	case <-w.done:
+		return os.ErrClosed
+	default:
+	}
+	close(w.done)
+	w.stack.Close()
+	for _, endpoint := range w.stack.CleanupEndpoints() {
+		endpoint.Abort()
+	}
+	w.stack.Wait()
+	close(w.outbound)
+	return nil
+}
+
+var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
+
+type wireEndpoint StackDevice
+
+func (ep *wireEndpoint) MTU() uint32 {
+	return ep.mtu
+}
+
+func (ep *wireEndpoint) MaxHeaderLength() uint16 {
+	return 0
+}
+
+func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
+	return ""
+}
+
+func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+	return stack.CapabilityNone
+}
+
+func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+	ep.dispatcher = dispatcher
+}
+
+func (ep *wireEndpoint) IsAttached() bool {
+	return ep.dispatcher != nil
+}
+
+func (ep *wireEndpoint) Wait() {
+}
+
+func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
+	return header.ARPHardwareNone
+}
+
+func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
+}
+
+func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
+	for _, packetBuffer := range list.AsSlice() {
+		packetBuffer.IncRef()
+		ep.outbound <- packetBuffer
+	}
+	return list.Len(), nil
+}

+ 9 - 0
transport/wireguard/device_stack_stub.go

@@ -0,0 +1,9 @@
+//go:build no_gvisor
+
+package wireguard
+
+import "github.com/sagernet/sing-tun"
+
+func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) {
+	return nil, tun.ErrGVisorNotIncluded
+}

+ 110 - 0
transport/wireguard/device_system.go

@@ -0,0 +1,110 @@
+package wireguard
+
+import (
+	"context"
+	"net"
+	"net/netip"
+	"os"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
+	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing-tun"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	wgTun "golang.zx2c4.com/wireguard/tun"
+)
+
+var _ Device = (*SystemDevice)(nil)
+
+type SystemDevice struct {
+	dialer N.Dialer
+	device tun.Tun
+	name   string
+	mtu    int
+	events chan wgTun.Event
+}
+
+/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) {
+	gTun, isGTun := w.device.(tun.GVisorTun)
+	if !isGTun {
+		return nil, tun.ErrGVisorUnsupported
+	}
+	return gTun.NewEndpoint()
+}*/
+
+func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
+	var inet4Addresses []netip.Prefix
+	var inet6Addresses []netip.Prefix
+	for _, prefixes := range localPrefixes {
+		if prefixes.Addr().Is4() {
+			inet4Addresses = append(inet4Addresses, prefixes)
+		} else {
+			inet6Addresses = append(inet6Addresses, prefixes)
+		}
+	}
+	if interfaceName == "" {
+		interfaceName = tun.CalculateInterfaceName("wg")
+	}
+	tunInterface, err := tun.Open(tun.Options{
+		Name:         interfaceName,
+		Inet4Address: inet4Addresses,
+		Inet6Address: inet6Addresses,
+		MTU:          mtu,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &SystemDevice{
+		dialer.NewDefault(router, option.DialerOptions{
+			BindInterface: interfaceName,
+		}),
+		tunInterface, interfaceName, int(mtu), make(chan wgTun.Event),
+	}, nil
+}
+
+func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	return w.dialer.DialContext(ctx, network, destination)
+}
+
+func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	return w.dialer.ListenPacket(ctx, destination)
+}
+
+func (w *SystemDevice) Start() error {
+	w.events <- wgTun.EventUp
+	return nil
+}
+
+func (w *SystemDevice) File() *os.File {
+	return nil
+}
+
+func (w *SystemDevice) Read(bytes []byte, index int) (int, error) {
+	return w.device.Read(bytes[index-tunPacketOffset:])
+}
+
+func (w *SystemDevice) Write(bytes []byte, index int) (int, error) {
+	return w.device.Write(bytes[index:])
+}
+
+func (w *SystemDevice) Flush() error {
+	return nil
+}
+
+func (w *SystemDevice) MTU() (int, error) {
+	return w.mtu, nil
+}
+
+func (w *SystemDevice) Name() (string, error) {
+	return w.name, nil
+}
+
+func (w *SystemDevice) Events() chan wgTun.Event {
+	return w.events
+}
+
+func (w *SystemDevice) Close() error {
+	return w.device.Close()
+}

+ 37 - 0
transport/wireguard/endpoint.go

@@ -0,0 +1,37 @@
+package wireguard
+
+import (
+	"net/netip"
+
+	M "github.com/sagernet/sing/common/metadata"
+
+	"golang.zx2c4.com/wireguard/conn"
+)
+
+var _ conn.Endpoint = (*Endpoint)(nil)
+
+type Endpoint M.Socksaddr
+
+func (e Endpoint) ClearSrc() {
+}
+
+func (e Endpoint) SrcToString() string {
+	return ""
+}
+
+func (e Endpoint) DstToString() string {
+	return (M.Socksaddr)(e).String()
+}
+
+func (e Endpoint) DstToBytes() []byte {
+	b, _ := (M.Socksaddr)(e).AddrPort().MarshalBinary()
+	return b
+}
+
+func (e Endpoint) DstIP() netip.Addr {
+	return (M.Socksaddr)(e).Addr
+}
+
+func (e Endpoint) SrcIP() netip.Addr {
+	return netip.Addr{}
+}

+ 22 - 0
transport/wireguard/error.go

@@ -0,0 +1,22 @@
+package wireguard
+
+import "net"
+
+type wireError struct {
+	cause error
+}
+
+func (w *wireError) Error() string {
+	return w.cause.Error()
+}
+
+func (w *wireError) Timeout() bool {
+	if cause, causeNet := w.cause.(net.Error); causeNet {
+		return cause.Timeout()
+	}
+	return false
+}
+
+func (w *wireError) Temporary() bool {
+	return true
+}

+ 96 - 0
transport/wireguard/server_bind.go

@@ -0,0 +1,96 @@
+package wireguard
+
+import (
+	"io"
+
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"golang.zx2c4.com/wireguard/conn"
+)
+
+var _ conn.Bind = (*ServerBind)(nil)
+
+type ServerBind struct {
+	inbound   chan serverPacket
+	done      chan struct{}
+	writeBack N.PacketWriter
+}
+
+func NewServerBind(writeBack N.PacketWriter) *ServerBind {
+	return &ServerBind{
+		inbound:   make(chan serverPacket, 256),
+		done:      make(chan struct{}),
+		writeBack: writeBack,
+	}
+}
+
+func (s *ServerBind) Abort() error {
+	select {
+	case <-s.done:
+		return io.ErrClosedPipe
+	default:
+		close(s.done)
+	}
+	return nil
+}
+
+type serverPacket struct {
+	buffer *buf.Buffer
+	source M.Socksaddr
+}
+
+func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+	fns = []conn.ReceiveFunc{s.receive}
+	return
+}
+
+func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
+	select {
+	case packet := <-s.inbound:
+		defer packet.buffer.Release()
+		n = copy(b, packet.buffer.Bytes())
+		ep = Endpoint(packet.source)
+		return
+	case <-s.done:
+		err = io.ErrClosedPipe
+		return
+	}
+}
+
+func (s *ServerBind) WriteIsThreadUnsafe() {
+}
+
+func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	select {
+	case s.inbound <- serverPacket{
+		buffer: buffer,
+		source: destination,
+	}:
+		return nil
+	case <-s.done:
+		return io.ErrClosedPipe
+	}
+}
+
+func (s *ServerBind) Close() error {
+	return nil
+}
+
+func (s *ServerBind) SetMark(mark uint32) error {
+	return nil
+}
+
+func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error {
+	return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(ep.(Endpoint)))
+}
+
+func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) {
+	destination := M.ParseSocksaddr(addr)
+	if !destination.IsValid() || destination.Port == 0 {
+		return nil, E.New("invalid endpoint: ", addr)
+	}
+	return Endpoint(destination), nil
+}

+ 3 - 0
transport/wireguard/tun_darwin.go

@@ -0,0 +1,3 @@
+package wireguard
+
+const tunPacketOffset = 4

+ 5 - 0
transport/wireguard/tun_nondarwin.go

@@ -0,0 +1,5 @@
+//go:build !darwin
+
+package wireguard
+
+const tunPacketOffset = 0