ソースを参照

Update wireguard-go

世界 2 年 前
コミット
b6068cea6b

+ 1 - 1
go.mod

@@ -35,7 +35,7 @@ require (
 	github.com/sagernet/tfo-go v0.0.0-20230303015439-ffcfd8c41cf9
 	github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2
 	github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e
-	github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c
+	github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.8.3
 	go.etcd.io/bbolt v1.3.7

+ 2 - 2
go.sum

@@ -131,8 +131,8 @@ github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 h1:kDUqhc9Vsk5HJuhfI
 github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM=
 github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs=
 github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY=
-github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c h1:vK2wyt9aWYHHvNLWniwijBu/n4pySypiKRhN32u/JGo=
-github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c/go.mod h1:euOmN6O5kk9dQmgSS8Df4psAl3TCjxOz0NW60EWkSaI=
+github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77 h1:g6QtRWQ2dKX7EQP++1JLNtw4C2TNxd4/ov8YUpOPOSo=
+github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77/go.mod h1:pJDdXzZIwJ+2vmnT0TKzmf8meeum+e2mTDSehw79eE0=
 github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
 github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

+ 27 - 16
transport/wireguard/client_bind.go

@@ -101,7 +101,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
 	return []conn.ReceiveFunc{c.receive}, 0, nil
 }
 
-func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
+func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
 	udpConn, err := c.connect()
 	if err != nil {
 		select {
@@ -113,22 +113,26 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
 		err = nil
 		return
 	}
-	n, addr, err := udpConn.ReadFrom(b)
+	n, addr, err := udpConn.ReadFrom(packets[0])
 	if err != nil {
 		udpConn.Close()
 		select {
 		case <-c.done:
 		default:
 			c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
+			err = nil
 		}
 		return
 	}
+	sizes[0] = n
 	if n > 3 {
+		b := packets[0]
 		b[1] = 0
 		b[2] = 0
 		b[3] = 0
 	}
-	ep = Endpoint(M.SocksaddrFromNet(addr))
+	eps[0] = Endpoint(M.SocksaddrFromNet(addr))
+	count = 1
 	return
 }
 
@@ -155,32 +159,39 @@ func (c *ClientBind) SetMark(mark uint32) error {
 	return nil
 }
 
-func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
+func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
 	udpConn, err := c.connect()
 	if err != nil {
 		return err
 	}
 	destination := M.Socksaddr(ep.(Endpoint))
-	if len(b) > 3 {
-		reserved, loaded := c.reservedForEndpoint[destination]
-		if !loaded {
-			reserved = c.reserved
+	for _, b := range bufs {
+		if len(b) > 3 {
+			reserved, loaded := c.reservedForEndpoint[destination]
+			if !loaded {
+				reserved = c.reserved
+			}
+			b[1] = reserved[0]
+			b[2] = reserved[1]
+			b[3] = reserved[2]
+		}
+		_, err = udpConn.WriteTo(b, destination)
+		if err != nil {
+			udpConn.Close()
+			return err
 		}
-		b[1] = reserved[0]
-		b[2] = reserved[1]
-		b[3] = reserved[2]
-	}
-	_, err = udpConn.WriteTo(b, destination)
-	if err != nil {
-		udpConn.Close()
 	}
-	return err
+	return nil
 }
 
 func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
 	return Endpoint(M.ParseSocksaddr(s)), nil
 }
 
+func (c *ClientBind) BatchSize() int {
+	return 1
+}
+
 type wireConn struct {
 	net.PacketConn
 	access sync.Mutex

+ 60 - 34
transport/wireguard/device_stack.go

@@ -8,10 +8,11 @@ import (
 	"net/netip"
 	"os"
 
+	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/wireguard-go/tun"
+	wgTun "github.com/sagernet/wireguard-go/tun"
 
 	"gvisor.dev/gvisor/pkg/bufferv2"
 	"gvisor.dev/gvisor/pkg/tcpip"
@@ -30,14 +31,15 @@ var _ Device = (*StackDevice)(nil)
 const defaultNIC tcpip.NICID = 1
 
 type StackDevice struct {
-	stack      *stack.Stack
-	mtu        uint32
-	events     chan tun.Event
-	outbound   chan *stack.PacketBuffer
-	done       chan struct{}
-	dispatcher stack.NetworkDispatcher
-	addr4      tcpip.Address
-	addr6      tcpip.Address
+	stack          *stack.Stack
+	mtu            uint32
+	events         chan wgTun.Event
+	outbound       chan *stack.PacketBuffer
+	packetOutbound chan *buf.Buffer
+	done           chan struct{}
+	dispatcher     stack.NetworkDispatcher
+	addr4          tcpip.Address
+	addr6          tcpip.Address
 }
 
 func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
@@ -47,11 +49,12 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
 		HandleLocal:        true,
 	})
 	tunDevice := &StackDevice{
-		stack:    ipStack,
-		mtu:      mtu,
-		events:   make(chan tun.Event, 1),
-		outbound: make(chan *stack.PacketBuffer, 256),
-		done:     make(chan struct{}),
+		stack:          ipStack,
+		mtu:            mtu,
+		events:         make(chan wgTun.Event, 1),
+		outbound:       make(chan *stack.PacketBuffer, 256),
+		packetOutbound: make(chan *buf.Buffer, 256),
+		done:           make(chan struct{}),
 	}
 	err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
 	if err != nil {
@@ -144,8 +147,16 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
 	return udpConn, nil
 }
 
+func (w *StackDevice) Inet4Address() netip.Addr {
+	return M.AddrFromIP(net.IP(w.addr4))
+}
+
+func (w *StackDevice) Inet6Address() netip.Addr {
+	return M.AddrFromIP(net.IP(w.addr6))
+}
+
 func (w *StackDevice) Start() error {
-	w.events <- tun.EventUp
+	w.events <- wgTun.EventUp
 	return nil
 }
 
@@ -153,41 +164,52 @@ func (w *StackDevice) File() *os.File {
 	return nil
 }
 
-func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
+func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
 	select {
 	case packetBuffer, ok := <-w.outbound:
 		if !ok {
 			return 0, os.ErrClosed
 		}
 		defer packetBuffer.DecRef()
+		p := bufs[0]
 		p = p[offset:]
+		n := 0
 		for _, slice := range packetBuffer.AsSlices() {
 			n += copy(p[n:], slice)
 		}
+		sizes[0] = n
+		count = 1
+		return
+	case packet := <-w.packetOutbound:
+		defer packet.Release()
+		sizes[0] = copy(bufs[0][offset:], packet.Bytes())
+		count = 1
 		return
 	case <-w.done:
 		return 0, os.ErrClosed
 	}
 }
 
-func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
-	p = p[offset:]
-	if len(p) == 0 {
-		return
-	}
-	var networkProtocol tcpip.NetworkProtocolNumber
-	switch header.IPVersion(p) {
-	case header.IPv4Version:
-		networkProtocol = header.IPv4ProtocolNumber
-	case header.IPv6Version:
-		networkProtocol = header.IPv6ProtocolNumber
+func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
+	for _, b := range bufs {
+		b = b[offset:]
+		if len(b) == 0 {
+			continue
+		}
+		var networkProtocol tcpip.NetworkProtocolNumber
+		switch header.IPVersion(b) {
+		case header.IPv4Version:
+			networkProtocol = header.IPv4ProtocolNumber
+		case header.IPv6Version:
+			networkProtocol = header.IPv6ProtocolNumber
+		}
+		packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			Payload: bufferv2.MakeWithData(b),
+		})
+		w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
+		packetBuffer.DecRef()
+		count++
 	}
-	packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		Payload: bufferv2.MakeWithData(p),
-	})
-	defer packetBuffer.DecRef()
-	w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
-	n = len(p)
 	return
 }
 
@@ -203,7 +225,7 @@ func (w *StackDevice) Name() (string, error) {
 	return "sing-box", nil
 }
 
-func (w *StackDevice) Events() chan tun.Event {
+func (w *StackDevice) Events() <-chan wgTun.Event {
 	return w.events
 }
 
@@ -222,6 +244,10 @@ func (w *StackDevice) Close() error {
 	return nil
 }
 
+func (w *StackDevice) BatchSize() int {
+	return 1
+}
+
 var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
 
 type wireEndpoint StackDevice

+ 45 - 15
transport/wireguard/device_system.go

@@ -23,16 +23,10 @@ type SystemDevice struct {
 	name   string
 	mtu    int
 	events chan wgTun.Event
+	addr4  netip.Addr
+	addr6  netip.Addr
 }
 
-/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) {
-	gTun, isGTun := w.device.(tun.GVisorTun)
-	if !isGTun {
-		return nil, tun.ErrGVisorUnsupported
-	}
-	return gTun.NewEndpoint()
-}*/
-
 func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
 	var inet4Addresses []netip.Prefix
 	var inet6Addresses []netip.Prefix
@@ -55,11 +49,24 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
 	if err != nil {
 		return nil, err
 	}
+	var inet4Address netip.Addr
+	var inet6Address netip.Addr
+	if len(inet4Addresses) > 0 {
+		inet4Address = inet4Addresses[0].Addr()
+	}
+	if len(inet6Addresses) > 0 {
+		inet6Address = inet6Addresses[0].Addr()
+	}
 	return &SystemDevice{
-		dialer.NewDefault(router, option.DialerOptions{
+		dialer: dialer.NewDefault(router, option.DialerOptions{
 			BindInterface: interfaceName,
 		}),
-		tunInterface, interfaceName, int(mtu), make(chan wgTun.Event),
+		device: tunInterface,
+		name:   interfaceName,
+		mtu:    int(mtu),
+		events: make(chan wgTun.Event),
+		addr4:  inet4Address,
+		addr6:  inet6Address,
 	}, nil
 }
 
@@ -71,6 +78,14 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
 	return w.dialer.ListenPacket(ctx, destination)
 }
 
+func (w *SystemDevice) Inet4Address() netip.Addr {
+	return w.addr4
+}
+
+func (w *SystemDevice) Inet6Address() netip.Addr {
+	return w.addr6
+}
+
 func (w *SystemDevice) Start() error {
 	w.events <- wgTun.EventUp
 	return nil
@@ -80,12 +95,23 @@ func (w *SystemDevice) File() *os.File {
 	return nil
 }
 
-func (w *SystemDevice) Read(bytes []byte, index int) (int, error) {
-	return w.device.Read(bytes[index-tun.PacketOffset:])
+func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
+	sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
+	if err == nil {
+		count = 1
+	}
+	return
 }
 
-func (w *SystemDevice) Write(bytes []byte, index int) (int, error) {
-	return w.device.Write(bytes[index:])
+func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
+	for _, b := range bufs {
+		_, err = w.device.Write(b[offset:])
+		if err != nil {
+			return
+		}
+		count++
+	}
+	return
 }
 
 func (w *SystemDevice) Flush() error {
@@ -100,10 +126,14 @@ func (w *SystemDevice) Name() (string, error) {
 	return w.name, nil
 }
 
-func (w *SystemDevice) Events() chan wgTun.Event {
+func (w *SystemDevice) Events() <-chan wgTun.Event {
 	return w.events
 }
 
 func (w *SystemDevice) Close() error {
 	return w.device.Close()
 }
+
+func (w *SystemDevice) BatchSize() int {
+	return 1
+}

+ 0 - 95
transport/wireguard/server_bind.go

@@ -1,95 +0,0 @@
-package wireguard
-
-import (
-	"io"
-
-	"github.com/sagernet/sing/common/buf"
-	E "github.com/sagernet/sing/common/exceptions"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/wireguard-go/conn"
-)
-
-var _ conn.Bind = (*ServerBind)(nil)
-
-type ServerBind struct {
-	inbound   chan serverPacket
-	done      chan struct{}
-	writeBack N.PacketWriter
-}
-
-func NewServerBind(writeBack N.PacketWriter) *ServerBind {
-	return &ServerBind{
-		inbound:   make(chan serverPacket, 256),
-		done:      make(chan struct{}),
-		writeBack: writeBack,
-	}
-}
-
-func (s *ServerBind) Abort() error {
-	select {
-	case <-s.done:
-		return io.ErrClosedPipe
-	default:
-		close(s.done)
-	}
-	return nil
-}
-
-type serverPacket struct {
-	buffer *buf.Buffer
-	source M.Socksaddr
-}
-
-func (s *ServerBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
-	fns = []conn.ReceiveFunc{s.receive}
-	return
-}
-
-func (s *ServerBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
-	select {
-	case packet := <-s.inbound:
-		defer packet.buffer.Release()
-		n = copy(b, packet.buffer.Bytes())
-		ep = Endpoint(packet.source)
-		return
-	case <-s.done:
-		err = io.ErrClosedPipe
-		return
-	}
-}
-
-func (s *ServerBind) WriteIsThreadUnsafe() {
-}
-
-func (s *ServerBind) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
-	select {
-	case s.inbound <- serverPacket{
-		buffer: buffer,
-		source: destination,
-	}:
-		return nil
-	case <-s.done:
-		return io.ErrClosedPipe
-	}
-}
-
-func (s *ServerBind) Close() error {
-	return nil
-}
-
-func (s *ServerBind) SetMark(mark uint32) error {
-	return nil
-}
-
-func (s *ServerBind) Send(b []byte, ep conn.Endpoint) error {
-	return s.writeBack.WritePacket(buf.As(b), M.Socksaddr(ep.(Endpoint)))
-}
-
-func (s *ServerBind) ParseEndpoint(addr string) (conn.Endpoint, error) {
-	destination := M.ParseSocksaddr(addr)
-	if !destination.IsValid() || destination.Port == 0 {
-		return nil, E.New("invalid endpoint: ", addr)
-	}
-	return Endpoint(destination), nil
-}