outbound.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package wireguard
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/hex"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "strings"
  10. "github.com/sagernet/sing-box/adapter"
  11. "github.com/sagernet/sing-box/adapter/outbound"
  12. "github.com/sagernet/sing-box/common/dialer"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing-box/log"
  15. "github.com/sagernet/sing-box/option"
  16. "github.com/sagernet/sing-box/transport/wireguard"
  17. "github.com/sagernet/sing-dns"
  18. "github.com/sagernet/sing-tun"
  19. "github.com/sagernet/sing/common"
  20. E "github.com/sagernet/sing/common/exceptions"
  21. "github.com/sagernet/sing/common/logger"
  22. M "github.com/sagernet/sing/common/metadata"
  23. N "github.com/sagernet/sing/common/network"
  24. "github.com/sagernet/sing/common/x/list"
  25. "github.com/sagernet/sing/service"
  26. "github.com/sagernet/sing/service/pause"
  27. "github.com/sagernet/wireguard-go/conn"
  28. "github.com/sagernet/wireguard-go/device"
  29. )
  30. func RegisterOutbound(registry *outbound.Registry) {
  31. outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound)
  32. }
  33. var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
  34. type Outbound struct {
  35. outbound.Adapter
  36. ctx context.Context
  37. router adapter.Router
  38. logger logger.ContextLogger
  39. workers int
  40. peers []wireguard.PeerConfig
  41. useStdNetBind bool
  42. listener N.Dialer
  43. ipcConf string
  44. pauseManager pause.Manager
  45. pauseCallback *list.Element[pause.Callback]
  46. bind conn.Bind
  47. device *device.Device
  48. tunDevice wireguard.Device
  49. }
  50. func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
  51. outbound := &Outbound{
  52. Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions),
  53. ctx: ctx,
  54. router: router,
  55. logger: logger,
  56. workers: options.Workers,
  57. pauseManager: service.FromContext[pause.Manager](ctx),
  58. }
  59. peers, err := wireguard.ParsePeers(options)
  60. if err != nil {
  61. return nil, err
  62. }
  63. outbound.peers = peers
  64. if len(options.LocalAddress) == 0 {
  65. return nil, E.New("missing local address")
  66. }
  67. if options.GSO {
  68. if options.GSO && options.Detour != "" {
  69. return nil, E.New("gso is conflict with detour")
  70. }
  71. options.IsWireGuardListener = true
  72. outbound.useStdNetBind = true
  73. }
  74. listener, err := dialer.New(ctx, options.DialerOptions)
  75. if err != nil {
  76. return nil, err
  77. }
  78. outbound.listener = listener
  79. var privateKey string
  80. {
  81. bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
  82. if err != nil {
  83. return nil, E.Cause(err, "decode private key")
  84. }
  85. privateKey = hex.EncodeToString(bytes)
  86. }
  87. outbound.ipcConf = "private_key=" + privateKey
  88. mtu := options.MTU
  89. if mtu == 0 {
  90. mtu = 1408
  91. }
  92. var wireTunDevice wireguard.Device
  93. if !options.SystemInterface && tun.WithGVisor {
  94. wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
  95. } else {
  96. wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
  97. }
  98. if err != nil {
  99. return nil, E.Cause(err, "create WireGuard device")
  100. }
  101. outbound.tunDevice = wireTunDevice
  102. return outbound, nil
  103. }
  104. func (w *Outbound) Start() error {
  105. if common.Any(w.peers, func(peer wireguard.PeerConfig) bool {
  106. return !peer.Endpoint.IsValid()
  107. }) {
  108. // wait for all outbounds to be started and continue in PortStart
  109. return nil
  110. }
  111. return w.start()
  112. }
  113. func (w *Outbound) PostStart() error {
  114. if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
  115. return peer.Endpoint.IsValid()
  116. }) {
  117. return nil
  118. }
  119. return w.start()
  120. }
  121. func (w *Outbound) start() error {
  122. err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
  123. if err != nil {
  124. return err
  125. }
  126. var bind conn.Bind
  127. if w.useStdNetBind {
  128. bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
  129. } else {
  130. var (
  131. isConnect bool
  132. connectAddr netip.AddrPort
  133. reserved [3]uint8
  134. )
  135. peerLen := len(w.peers)
  136. if peerLen == 1 {
  137. isConnect = true
  138. connectAddr = w.peers[0].Endpoint
  139. reserved = w.peers[0].Reserved
  140. }
  141. bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
  142. }
  143. if w.useStdNetBind || len(w.peers) > 1 {
  144. for _, peer := range w.peers {
  145. if peer.Reserved != [3]uint8{} {
  146. bind.SetReservedForEndpoint(peer.Endpoint, peer.Reserved)
  147. }
  148. }
  149. }
  150. err = w.tunDevice.Start()
  151. if err != nil {
  152. return err
  153. }
  154. wgDevice := device.NewDevice(w.ctx, w.tunDevice, bind, &device.Logger{
  155. Verbosef: func(format string, args ...interface{}) {
  156. w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
  157. },
  158. Errorf: func(format string, args ...interface{}) {
  159. w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
  160. },
  161. }, w.workers)
  162. ipcConf := w.ipcConf
  163. for _, peer := range w.peers {
  164. ipcConf += peer.GenerateIpcLines()
  165. }
  166. err = wgDevice.IpcSet(ipcConf)
  167. if err != nil {
  168. return E.Cause(err, "setup wireguard: \n", ipcConf)
  169. }
  170. w.device = wgDevice
  171. w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
  172. return nil
  173. }
  174. func (w *Outbound) Close() error {
  175. if w.device != nil {
  176. w.device.Close()
  177. }
  178. if w.pauseCallback != nil {
  179. w.pauseManager.UnregisterCallback(w.pauseCallback)
  180. }
  181. return nil
  182. }
  183. func (w *Outbound) InterfaceUpdated() {
  184. w.device.BindUpdate()
  185. return
  186. }
  187. func (w *Outbound) onPauseUpdated(event int) {
  188. switch event {
  189. case pause.EventDevicePaused:
  190. w.device.Down()
  191. case pause.EventDeviceWake:
  192. w.device.Up()
  193. }
  194. }
  195. func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  196. switch network {
  197. case N.NetworkTCP:
  198. w.logger.InfoContext(ctx, "outbound connection to ", destination)
  199. case N.NetworkUDP:
  200. w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
  201. }
  202. if destination.IsFqdn() {
  203. destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
  204. if err != nil {
  205. return nil, err
  206. }
  207. return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
  208. }
  209. return w.tunDevice.DialContext(ctx, network, destination)
  210. }
  211. func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  212. w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
  213. if destination.IsFqdn() {
  214. destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
  215. if err != nil {
  216. return nil, err
  217. }
  218. packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
  219. if err != nil {
  220. return nil, err
  221. }
  222. return packetConn, err
  223. }
  224. return w.tunDevice.ListenPacket(ctx, destination)
  225. }
  226. // TODO
  227. // Deprecated
  228. func (w *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
  229. return outbound.NewDirectConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS)
  230. }
  231. // TODO
  232. // Deprecated
  233. func (w *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
  234. return outbound.NewDirectPacketConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS)
  235. }