|
@@ -8,6 +8,7 @@ import (
|
|
|
"encoding/hex"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
+ "net/netip"
|
|
|
"strings"
|
|
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
@@ -18,10 +19,12 @@ import (
|
|
|
"github.com/sagernet/sing-box/transport/wireguard"
|
|
|
"github.com/sagernet/sing-dns"
|
|
|
"github.com/sagernet/sing-tun"
|
|
|
- "github.com/sagernet/sing/common/debug"
|
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
+ "github.com/sagernet/sing/common/x/list"
|
|
|
+ "github.com/sagernet/sing/service/pause"
|
|
|
+ "github.com/sagernet/wireguard-go/conn"
|
|
|
"github.com/sagernet/wireguard-go/device"
|
|
|
)
|
|
|
|
|
@@ -32,9 +35,18 @@ var (
|
|
|
|
|
|
type WireGuard struct {
|
|
|
myOutboundAdapter
|
|
|
- bind *wireguard.ClientBind
|
|
|
- device *device.Device
|
|
|
- tunDevice wireguard.Device
|
|
|
+ ctx context.Context
|
|
|
+ workers int
|
|
|
+ peers []wireguard.PeerConfig
|
|
|
+ useStdNetBind bool
|
|
|
+ listener N.Dialer
|
|
|
+ ipcConf string
|
|
|
+
|
|
|
+ pauseManager pause.Manager
|
|
|
+ pauseCallback *list.Element[pause.Callback]
|
|
|
+ bind conn.Bind
|
|
|
+ device *device.Device
|
|
|
+ tunDevice wireguard.Device
|
|
|
}
|
|
|
|
|
|
func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) {
|
|
@@ -47,32 +59,30 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
|
|
|
tag: tag,
|
|
|
dependencies: withDialerDependency(options.DialerOptions),
|
|
|
},
|
|
|
+ ctx: ctx,
|
|
|
+ workers: options.Workers,
|
|
|
+ pauseManager: pause.ManagerFromContext(ctx),
|
|
|
}
|
|
|
- var reserved [3]uint8
|
|
|
- if len(options.Reserved) > 0 {
|
|
|
- if len(options.Reserved) != 3 {
|
|
|
- return nil, E.New("invalid reserved value, required 3 bytes, got ", len(options.Reserved))
|
|
|
- }
|
|
|
- copy(reserved[:], options.Reserved)
|
|
|
- }
|
|
|
- var isConnect bool
|
|
|
- var connectAddr M.Socksaddr
|
|
|
- if len(options.Peers) < 2 {
|
|
|
- isConnect = true
|
|
|
- if len(options.Peers) == 1 {
|
|
|
- connectAddr = options.Peers[0].ServerOptions.Build()
|
|
|
- } else {
|
|
|
- connectAddr = options.ServerOptions.Build()
|
|
|
- }
|
|
|
- }
|
|
|
- outboundDialer, err := dialer.New(router, options.DialerOptions)
|
|
|
+ peers, err := wireguard.ParsePeers(options)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- outbound.bind = wireguard.NewClientBind(ctx, outbound, outboundDialer, isConnect, connectAddr, reserved)
|
|
|
+ outbound.peers = peers
|
|
|
if len(options.LocalAddress) == 0 {
|
|
|
return nil, E.New("missing local address")
|
|
|
}
|
|
|
+ if options.GSO {
|
|
|
+ if options.GSO && options.Detour != "" {
|
|
|
+ return nil, E.New("gso is conflict with detour")
|
|
|
+ }
|
|
|
+ options.IsWireGuardListener = true
|
|
|
+ outbound.useStdNetBind = true
|
|
|
+ }
|
|
|
+ listener, err := dialer.New(router, options.DialerOptions)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ outbound.listener = listener
|
|
|
var privateKey string
|
|
|
{
|
|
|
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
|
|
@@ -81,80 +91,7 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
|
|
|
}
|
|
|
privateKey = hex.EncodeToString(bytes)
|
|
|
}
|
|
|
- ipcConf := "private_key=" + privateKey
|
|
|
- if len(options.Peers) > 0 {
|
|
|
- for i, peer := range options.Peers {
|
|
|
- var peerPublicKey, preSharedKey string
|
|
|
- {
|
|
|
- bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "decode public key for peer ", i)
|
|
|
- }
|
|
|
- peerPublicKey = hex.EncodeToString(bytes)
|
|
|
- }
|
|
|
- if peer.PreSharedKey != "" {
|
|
|
- bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "decode pre shared key for peer ", i)
|
|
|
- }
|
|
|
- preSharedKey = hex.EncodeToString(bytes)
|
|
|
- }
|
|
|
- destination := peer.ServerOptions.Build()
|
|
|
- ipcConf += "\npublic_key=" + peerPublicKey
|
|
|
- ipcConf += "\nendpoint=" + destination.String()
|
|
|
- if preSharedKey != "" {
|
|
|
- ipcConf += "\npreshared_key=" + preSharedKey
|
|
|
- }
|
|
|
- if len(peer.AllowedIPs) == 0 {
|
|
|
- return nil, E.New("missing allowed_ips for peer ", i)
|
|
|
- }
|
|
|
- for _, allowedIP := range peer.AllowedIPs {
|
|
|
- ipcConf += "\nallowed_ip=" + allowedIP
|
|
|
- }
|
|
|
- if len(peer.Reserved) > 0 {
|
|
|
- if len(peer.Reserved) != 3 {
|
|
|
- return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved))
|
|
|
- }
|
|
|
- copy(reserved[:], options.Reserved)
|
|
|
- outbound.bind.SetReservedForEndpoint(destination, reserved)
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- var peerPublicKey, preSharedKey string
|
|
|
- {
|
|
|
- bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "decode peer public key")
|
|
|
- }
|
|
|
- peerPublicKey = hex.EncodeToString(bytes)
|
|
|
- }
|
|
|
- if options.PreSharedKey != "" {
|
|
|
- bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey)
|
|
|
- if err != nil {
|
|
|
- return nil, E.Cause(err, "decode pre shared key")
|
|
|
- }
|
|
|
- preSharedKey = hex.EncodeToString(bytes)
|
|
|
- }
|
|
|
- ipcConf += "\npublic_key=" + peerPublicKey
|
|
|
- ipcConf += "\nendpoint=" + options.ServerOptions.Build().String()
|
|
|
- if preSharedKey != "" {
|
|
|
- ipcConf += "\npreshared_key=" + preSharedKey
|
|
|
- }
|
|
|
- var has4, has6 bool
|
|
|
- for _, address := range options.LocalAddress {
|
|
|
- if address.Addr().Is4() {
|
|
|
- has4 = true
|
|
|
- } else {
|
|
|
- has6 = true
|
|
|
- }
|
|
|
- }
|
|
|
- if has4 {
|
|
|
- ipcConf += "\nallowed_ip=0.0.0.0/0"
|
|
|
- }
|
|
|
- if has6 {
|
|
|
- ipcConf += "\nallowed_ip=::/0"
|
|
|
- }
|
|
|
- }
|
|
|
+ outbound.ipcConf = "private_key=" + privateKey
|
|
|
mtu := options.MTU
|
|
|
if mtu == 0 {
|
|
|
mtu = 1408
|
|
@@ -163,36 +100,83 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context
|
|
|
if !options.SystemInterface && tun.WithGVisor {
|
|
|
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
|
|
|
} else {
|
|
|
- wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, options.LocalAddress, mtu)
|
|
|
+ wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, options.LocalAddress, mtu, options.GSO)
|
|
|
}
|
|
|
if err != nil {
|
|
|
return nil, E.Cause(err, "create WireGuard device")
|
|
|
}
|
|
|
- wgDevice := device.NewDevice(ctx, wireTunDevice, outbound.bind, &device.Logger{
|
|
|
+ outbound.tunDevice = wireTunDevice
|
|
|
+ return outbound, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (w *WireGuard) Start() error {
|
|
|
+ err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ var bind conn.Bind
|
|
|
+ if w.useStdNetBind {
|
|
|
+ bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
|
|
|
+ } else {
|
|
|
+ var (
|
|
|
+ isConnect bool
|
|
|
+ connectAddr netip.AddrPort
|
|
|
+ reserved [3]uint8
|
|
|
+ )
|
|
|
+ peerLen := len(w.peers)
|
|
|
+ if peerLen == 1 {
|
|
|
+ isConnect = true
|
|
|
+ connectAddr = w.peers[0].Endpoint
|
|
|
+ reserved = w.peers[0].Reserved
|
|
|
+ }
|
|
|
+ bind = wireguard.NewClientBind(w.ctx, w, w.listener, isConnect, connectAddr, reserved)
|
|
|
+ }
|
|
|
+ wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
|
|
|
Verbosef: func(format string, args ...interface{}) {
|
|
|
- logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
|
|
+ w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
|
|
|
},
|
|
|
Errorf: func(format string, args ...interface{}) {
|
|
|
- logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
|
|
+ w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
|
|
|
},
|
|
|
- }, options.Workers)
|
|
|
- if debug.Enabled {
|
|
|
- logger.Trace("created wireguard ipc conf: \n", ipcConf)
|
|
|
+ }, w.workers)
|
|
|
+ ipcConf := w.ipcConf
|
|
|
+ for _, peer := range w.peers {
|
|
|
+ ipcConf += peer.GenerateIpcLines()
|
|
|
}
|
|
|
err = wgDevice.IpcSet(ipcConf)
|
|
|
if err != nil {
|
|
|
- return nil, E.Cause(err, "setup wireguard")
|
|
|
+ return E.Cause(err, "setup wireguard: \n", ipcConf)
|
|
|
}
|
|
|
- outbound.device = wgDevice
|
|
|
- outbound.tunDevice = wireTunDevice
|
|
|
- return outbound, nil
|
|
|
+ w.device = wgDevice
|
|
|
+ w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
|
|
|
+ return w.tunDevice.Start()
|
|
|
+}
|
|
|
+
|
|
|
+func (w *WireGuard) Close() error {
|
|
|
+ if w.device != nil {
|
|
|
+ w.device.Close()
|
|
|
+ }
|
|
|
+ if w.pauseCallback != nil {
|
|
|
+ w.pauseManager.UnregisterCallback(w.pauseCallback)
|
|
|
+ }
|
|
|
+ w.tunDevice.Close()
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (w *WireGuard) InterfaceUpdated() {
|
|
|
- w.bind.Reset()
|
|
|
+ w.device.BindUpdate()
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+func (w *WireGuard) onPauseUpdated(event int) {
|
|
|
+ switch event {
|
|
|
+ case pause.EventDevicePaused:
|
|
|
+ w.device.Down()
|
|
|
+ case pause.EventDeviceWake:
|
|
|
+ w.device.Up()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
|
|
switch network {
|
|
|
case N.NetworkTCP:
|
|
@@ -233,15 +217,3 @@ func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata a
|
|
|
func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
|
|
return NewDirectPacketConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS)
|
|
|
}
|
|
|
-
|
|
|
-func (w *WireGuard) Start() error {
|
|
|
- return w.tunDevice.Start()
|
|
|
-}
|
|
|
-
|
|
|
-func (w *WireGuard) Close() error {
|
|
|
- if w.device != nil {
|
|
|
- w.device.Close()
|
|
|
- }
|
|
|
- w.tunDevice.Close()
|
|
|
- return nil
|
|
|
-}
|