Răsfoiți Sursa

Add multi-peer support for wireguard outbound

世界 2 ani în urmă
părinte
comite
e168de79c7

+ 31 - 3
docs/configuration/outbound/wireguard.md

@@ -13,6 +13,18 @@
     "10.0.0.2/32"
   ],
   "private_key": "YNXtAzepDqRv9H52osJVDQnznT5AM11eCK3ESpwSt04=",
+  "peers": [
+    {
+      "server": "127.0.0.1",
+      "server_port": 1080,
+      "public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=",
+      "pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=",
+      "allowed_ips": [
+        "0.0.0.0/0"
+      ],
+      "reserved": [0, 0, 0]
+    }
+  ],
   "peer_public_key": "Z1XXLsKYkYxuiYjJIkRvtIKFepCYHTgON+GwPq7SOV4=",
   "pre_shared_key": "31aIhAPwktDGpH4JDhA8GNvjFXEf/a6+UaQRyOAiyfM=",
   "reserved": [0, 0, 0],
@@ -36,13 +48,13 @@
 
 #### server
 
-==Required==
+==Required if multi-peer disabled==
 
 The server address.
 
 #### server_port
 
-==Required==
+==Required if multi-peer disabled==
 
 The server port.
 
@@ -75,9 +87,25 @@ wg genkey
 echo "private key" || wg pubkey
 ```
 
+#### peers
+
+Multi-peer support. 
+
+If enabled, `server, server_port, peer_public_key, pre_shared_key` will be ignored.
+
+#### peers.allowed_ips
+
+WireGuard allowed IPs.
+
+#### peers.reserved
+
+WireGuard reserved field bytes.
+
+`$outbound.reserved` will be used if empty.
+
 #### peer_public_key
 
-==Required==
+==Required if multi-peer disabled==
 
 WireGuard peer public key.
 

+ 16 - 7
option/wireguard.go

@@ -2,15 +2,24 @@ package option
 
 type WireGuardOutboundOptions struct {
 	DialerOptions
-	ServerOptions
 	SystemInterface bool                   `json:"system_interface,omitempty"`
 	InterfaceName   string                 `json:"interface_name,omitempty"`
 	LocalAddress    Listable[ListenPrefix] `json:"local_address"`
 	PrivateKey      string                 `json:"private_key"`
-	PeerPublicKey   string                 `json:"peer_public_key"`
-	PreSharedKey    string                 `json:"pre_shared_key,omitempty"`
-	Reserved        []uint8                `json:"reserved,omitempty"`
-	Workers         int                    `json:"workers,omitempty"`
-	MTU             uint32                 `json:"mtu,omitempty"`
-	Network         NetworkList            `json:"network,omitempty"`
+	Peers           []WireGuardPeer        `json:"peers,omitempty"`
+	ServerOptions
+	PeerPublicKey string      `json:"peer_public_key"`
+	PreSharedKey  string      `json:"pre_shared_key,omitempty"`
+	Reserved      []uint8     `json:"reserved,omitempty"`
+	Workers       int         `json:"workers,omitempty"`
+	MTU           uint32      `json:"mtu,omitempty"`
+	Network       NetworkList `json:"network,omitempty"`
+}
+
+type WireGuardPeer struct {
+	ServerOptions
+	PublicKey    string           `json:"public_key,omitempty"`
+	PreSharedKey string           `json:"pre_shared_key,omitempty"`
+	AllowedIPs   Listable[string] `json:"allowed_ips,omitempty"`
+	Reserved     []uint8          `json:"reserved,omitempty"`
 }

+ 82 - 33
outbound/wireguard.go

@@ -54,13 +54,22 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 		}
 		copy(reserved[:], options.Reserved)
 	}
-	peerAddr := options.ServerOptions.Build()
-	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), peerAddr, reserved)
+	var isConnect bool
+	var connectAddr M.Socksaddr
+	if len(options.Peers) < 2 {
+		isConnect = true
+		if len(options.Peers) == 1 {
+			connectAddr = options.Peers[0].ServerOptions.Build()
+		} else {
+			connectAddr = options.ServerOptions.Build()
+		}
+	}
+	outbound.bind = wireguard.NewClientBind(ctx, dialer.New(router, options.DialerOptions), isConnect, connectAddr, reserved)
 	localPrefixes := common.Map(options.LocalAddress, option.ListenPrefix.Build)
 	if len(localPrefixes) == 0 {
 		return nil, E.New("missing local address")
 	}
-	var privateKey, peerPublicKey, preSharedKey string
+	var privateKey string
 	{
 		bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
 		if err != nil {
@@ -68,39 +77,79 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
 		}
 		privateKey = hex.EncodeToString(bytes)
 	}
-	{
-		bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
-		if err != nil {
-			return nil, E.Cause(err, "decode peer public key")
+	ipcConf := "private_key=" + privateKey
+	if len(options.Peers) > 0 {
+		for i, peer := range options.Peers {
+			var peerPublicKey, preSharedKey string
+			{
+				bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey)
+				if err != nil {
+					return nil, E.Cause(err, "decode public key for peer ", i)
+				}
+				peerPublicKey = hex.EncodeToString(bytes)
+			}
+			if peer.PreSharedKey != "" {
+				bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey)
+				if err != nil {
+					return nil, E.Cause(err, "decode pre shared key for peer ", i)
+				}
+				preSharedKey = hex.EncodeToString(bytes)
+			}
+			destination := peer.ServerOptions.Build()
+			ipcConf += "\npublic_key=" + peerPublicKey
+			ipcConf += "\nendpoint=" + destination.String()
+			if preSharedKey != "" {
+				ipcConf += "\npreshared_key=" + preSharedKey
+			}
+			if len(peer.AllowedIPs) == 0 {
+				return nil, E.New("missing allowed_ips for peer ", i)
+			}
+			for _, allowedIP := range peer.AllowedIPs {
+				ipcConf += "\nallowed_ip=" + allowedIP
+			}
+			if len(peer.Reserved) > 0 {
+				if len(peer.Reserved) != 3 {
+					return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved))
+				}
+				copy(reserved[:], options.Reserved)
+				outbound.bind.SetReservedForEndpoint(destination, reserved)
+			}
 		}
-		peerPublicKey = hex.EncodeToString(bytes)
-	}
-	if options.PreSharedKey != "" {
-		bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
-		if err != nil {
-			return nil, E.Cause(err, "decode pre shared key")
+	} else {
+		var peerPublicKey, preSharedKey string
+		{
+			bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
+			if err != nil {
+				return nil, E.Cause(err, "decode peer public key")
+			}
+			peerPublicKey = hex.EncodeToString(bytes)
 		}
-		preSharedKey = hex.EncodeToString(bytes)
-	}
-	ipcConf := "private_key=" + privateKey
-	ipcConf += "\npublic_key=" + peerPublicKey
-	ipcConf += "\nendpoint=" + peerAddr.String()
-	if preSharedKey != "" {
-		ipcConf += "\npreshared_key=" + preSharedKey
-	}
-	var has4, has6 bool
-	for _, address := range localPrefixes {
-		if address.Addr().Is4() {
-			has4 = true
-		} else {
-			has6 = true
+		if options.PreSharedKey != "" {
+			bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
+			if err != nil {
+				return nil, E.Cause(err, "decode pre shared key")
+			}
+			preSharedKey = hex.EncodeToString(bytes)
+		}
+		ipcConf += "\npublic_key=" + peerPublicKey
+		ipcConf += "\nendpoint=" + options.ServerOptions.Build().String()
+		if preSharedKey != "" {
+			ipcConf += "\npreshared_key=" + preSharedKey
+		}
+		var has4, has6 bool
+		for _, address := range localPrefixes {
+			if address.Addr().Is4() {
+				has4 = true
+			} else {
+				has6 = true
+			}
+		}
+		if has4 {
+			ipcConf += "\nallowed_ip=0.0.0.0/0"
+		}
+		if has6 {
+			ipcConf += "\nallowed_ip=::/0"
 		}
-	}
-	if has4 {
-		ipcConf += "\nallowed_ip=0.0.0.0/0"
-	}
-	if has6 {
-		ipcConf += "\nallowed_ip=::/0"
 	}
 	mtu := options.MTU
 	if mtu == 0 {

+ 63 - 32
transport/wireguard/client_bind.go

@@ -3,9 +3,12 @@ package wireguard
 import (
 	"context"
 	"net"
+	"net/netip"
 	"sync"
 
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/wireguard-go/conn"
@@ -14,24 +17,32 @@ import (
 var _ conn.Bind = (*ClientBind)(nil)
 
 type ClientBind struct {
-	ctx        context.Context
-	dialer     N.Dialer
-	peerAddr   M.Socksaddr
-	reserved   [3]uint8
-	connAccess sync.Mutex
-	conn       *wireConn
-	done       chan struct{}
+	ctx                 context.Context
+	dialer              N.Dialer
+	reservedForEndpoint map[M.Socksaddr][3]uint8
+	connAccess          sync.Mutex
+	conn                *wireConn
+	done                chan struct{}
+	isConnect           bool
+	connectAddr         M.Socksaddr
+	reserved            [3]uint8
 }
 
-func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
+func NewClientBind(ctx context.Context, dialer N.Dialer, isConnect bool, connectAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
 	return &ClientBind{
-		ctx:      ctx,
-		dialer:   dialer,
-		peerAddr: peerAddr,
-		reserved: reserved,
+		ctx:                 ctx,
+		dialer:              dialer,
+		reservedForEndpoint: make(map[M.Socksaddr][3]uint8),
+		isConnect:           isConnect,
+		connectAddr:         connectAddr,
+		reserved:            reserved,
 	}
 }
 
+func (c *ClientBind) SetReservedForEndpoint(destination M.Socksaddr, reserved [3]byte) {
+	c.reservedForEndpoint[destination] = reserved
+}
+
 func (c *ClientBind) connect() (*wireConn, error) {
 	serverConn := c.conn
 	if serverConn != nil {
@@ -53,13 +64,27 @@ func (c *ClientBind) connect() (*wireConn, error) {
 			return serverConn, nil
 		}
 	}
-	udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
-	if err != nil {
-		return nil, &wireError{err}
-	}
-	c.conn = &wireConn{
-		Conn: udpConn,
-		done: make(chan struct{}),
+	if c.isConnect {
+		udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, c.connectAddr)
+		if err != nil {
+			return nil, &wireError{err}
+		}
+		c.conn = &wireConn{
+			NetPacketConn: &bufio.UnbindPacketConn{
+				ExtendedConn: bufio.NewExtendedConn(udpConn),
+				Addr:         c.connectAddr,
+			},
+			done: make(chan struct{}),
+		}
+	} else {
+		udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
+		if err != nil {
+			return nil, &wireError{err}
+		}
+		c.conn = &wireConn{
+			NetPacketConn: bufio.NewPacketConn(udpConn),
+			done:          make(chan struct{}),
+		}
 	}
 	return c.conn, nil
 }
@@ -80,7 +105,8 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
 		err = &wireError{err}
 		return
 	}
-	n, err = udpConn.Read(b)
+	buffer := buf.With(b)
+	destination, err := udpConn.ReadPacket(buffer)
 	if err != nil {
 		udpConn.Close()
 		select {
@@ -90,12 +116,16 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
 		}
 		return
 	}
+	n = buffer.Len()
+	if buffer.Start() > 0 {
+		copy(b, buffer.Bytes())
+	}
 	if n > 3 {
 		b[1] = 0
 		b[2] = 0
 		b[3] = 0
 	}
-	ep = Endpoint(c.peerAddr)
+	ep = Endpoint(destination)
 	return
 }
 
@@ -127,12 +157,17 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
 	if err != nil {
 		return err
 	}
+	destination := M.Socksaddr(ep.(Endpoint))
 	if len(b) > 3 {
-		b[1] = c.reserved[0]
-		b[2] = c.reserved[1]
-		b[3] = c.reserved[2]
+		reserved, loaded := c.reservedForEndpoint[destination]
+		if !loaded {
+			reserved = c.reserved
+		}
+		b[1] = reserved[0]
+		b[2] = reserved[1]
+		b[3] = reserved[2]
 	}
-	_, err = udpConn.Write(b)
+	err = udpConn.WritePacket(buf.As(b), destination)
 	if err != nil {
 		udpConn.Close()
 	}
@@ -140,15 +175,11 @@ func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
 }
 
 func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
-	return Endpoint(c.peerAddr), nil
-}
-
-func (c *ClientBind) Endpoint() conn.Endpoint {
-	return Endpoint(c.peerAddr)
+	return Endpoint(M.ParseSocksaddr(s)), nil
 }
 
 type wireConn struct {
-	net.Conn
+	N.NetPacketConn
 	access sync.Mutex
 	done   chan struct{}
 }
@@ -161,7 +192,7 @@ func (w *wireConn) Close() error {
 		return net.ErrClosed
 	default:
 	}
-	w.Conn.Close()
+	w.NetPacketConn.Close()
 	close(w.done)
 	return nil
 }