wireguard.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. //go:build with_wireguard
  2. package outbound
  3. import (
  4. "context"
  5. "encoding/base64"
  6. "encoding/hex"
  7. "fmt"
  8. "net"
  9. "net/netip"
  10. "os"
  11. "strings"
  12. "sync"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/common/dialer"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/log"
  17. "github.com/sagernet/sing-box/option"
  18. "github.com/sagernet/sing/common"
  19. "github.com/sagernet/sing/common/debug"
  20. E "github.com/sagernet/sing/common/exceptions"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. "golang.zx2c4.com/wireguard/conn"
  24. "golang.zx2c4.com/wireguard/device"
  25. "golang.zx2c4.com/wireguard/tun"
  26. "gvisor.dev/gvisor/pkg/bufferv2"
  27. "gvisor.dev/gvisor/pkg/tcpip"
  28. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  29. "gvisor.dev/gvisor/pkg/tcpip/header"
  30. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  31. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  32. "gvisor.dev/gvisor/pkg/tcpip/stack"
  33. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  34. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  35. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  36. )
  37. var _ adapter.Outbound = (*WireGuard)(nil)
  38. type WireGuard struct {
  39. myOutboundAdapter
  40. ctx context.Context
  41. serverAddr M.Socksaddr
  42. dialer N.Dialer
  43. endpoint conn.Endpoint
  44. device *device.Device
  45. tunDevice *wireTunDevice
  46. connAccess sync.Mutex
  47. conn *wireConn
  48. }
  // NewWireGuard creates a WireGuard outbound from options.
  //
  // It parses options.LocalAddress (each entry is either a bare IP or a CIDR
  // prefix) and decodes the base64 private, peer public and optional
  // pre-shared keys into the hex form used by the wireguard-go IPC
  // configuration. It then builds the gVisor-backed tun device and applies the
  // configuration to a new wireguard-go device. A default allowed_ip is added
  // for each address family present in the local addresses, and the MTU
  // defaults to 1408 when unset.
  //
  // It returns an error when no local address is configured, when an address
  // or key fails to parse, when the tun device cannot be built, or when the
  // device rejects the configuration. In the last case the wireguard-go device
  // is closed before returning so it does not leak.
  func NewWireGuard(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (*WireGuard, error) {
  	outbound := &WireGuard{
  		myOutboundAdapter: myOutboundAdapter{
  			protocol: C.TypeWireGuard,
  			network:  options.Network.Build(),
  			router:   router,
  			logger:   logger,
  			tag:      tag,
  		},
  		ctx:        ctx,
  		serverAddr: options.ServerOptions.Build(),
  		dialer:     dialer.NewOutbound(router, options.OutboundDialerOptions),
  	}
  	// wireguard-go needs an IP endpoint. For a domain server a loopback
  	// placeholder is reported instead; the bind always dials serverAddr itself
  	// and ignores the endpoint it is given.
  	var endpointIp netip.Addr
  	if outbound.serverAddr.IsFqdn() {
  		endpointIp = netip.AddrFrom4([4]byte{127, 0, 0, 1})
  	} else {
  		endpointIp = outbound.serverAddr.Addr
  	}
  	outbound.endpoint = conn.StdNetEndpoint(netip.AddrPortFrom(endpointIp, outbound.serverAddr.Port))

  	if len(options.LocalAddress) == 0 {
  		return nil, E.New("missing local address")
  	}
  	localAddress := make([]tcpip.AddressWithPrefix, len(options.LocalAddress))
  	for index, address := range options.LocalAddress {
  		if strings.Contains(address, "/") {
  			prefix, err := netip.ParsePrefix(address)
  			if err != nil {
  				return nil, E.Cause(err, "parse local address prefix ", address)
  			}
  			localAddress[index] = tcpip.AddressWithPrefix{
  				Address:   tcpip.Address(prefix.Addr().AsSlice()),
  				PrefixLen: prefix.Bits(),
  			}
  			continue
  		}
  		addr, err := netip.ParseAddr(address)
  		if err != nil {
  			return nil, E.Cause(err, "parse local address ", address)
  		}
  		localAddress[index] = tcpip.Address(addr.AsSlice()).WithPrefix()
  	}

  	// Options carry base64 keys; the wireguard-go IPC configuration takes hex.
  	decodeKey := func(encoded string, description string) (string, error) {
  		raw, err := base64.StdEncoding.DecodeString(encoded)
  		if err != nil {
  			return "", E.Cause(err, "decode ", description)
  		}
  		return hex.EncodeToString(raw), nil
  	}
  	privateKey, err := decodeKey(options.PrivateKey, "private key")
  	if err != nil {
  		return nil, err
  	}
  	peerPublicKey, err := decodeKey(options.PeerPublicKey, "peer public key")
  	if err != nil {
  		return nil, err
  	}
  	var preSharedKey string
  	if options.PreSharedKey != "" {
  		preSharedKey, err = decodeKey(options.PreSharedKey, "pre shared key")
  		if err != nil {
  			return nil, err
  		}
  	}

  	ipcConf := "private_key=" + privateKey
  	ipcConf += "\npublic_key=" + peerPublicKey
  	ipcConf += "\nendpoint=" + outbound.endpoint.DstToString()
  	if preSharedKey != "" {
  		ipcConf += "\npreshared_key=" + preSharedKey
  	}
  	var has4, has6 bool
  	for _, address := range localAddress {
  		if address.Address.To4() != "" {
  			has4 = true
  		} else {
  			has6 = true
  		}
  	}
  	if has4 {
  		ipcConf += "\nallowed_ip=0.0.0.0/0"
  	}
  	if has6 {
  		ipcConf += "\nallowed_ip=::/0"
  	}

  	mtu := options.MTU
  	if mtu == 0 {
  		mtu = 1408
  	}
  	wireDevice, err := newWireDevice(localAddress, mtu)
  	if err != nil {
  		return nil, err
  	}
  	wgDevice := device.NewDevice(wireDevice, (*wireClientBind)(outbound), &device.Logger{
  		Verbosef: func(format string, args ...interface{}) {
  			logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
  		},
  		Errorf: func(format string, args ...interface{}) {
  			logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
  		},
  	})
  	if debug.Enabled {
  		logger.Trace("created wireguard ipc conf: \n", ipcConf)
  	}
  	if err = wgDevice.IpcSet(ipcConf); err != nil {
  		// wgDevice was built on top of wireDevice and this bind. Closing it is
  		// expected to stop its routines and close the tun device (and with it
  		// the gVisor stack); without this they leaked on this error path.
  		wgDevice.Close()
  		return nil, E.Cause(err, "setup wireguard")
  	}
  	outbound.device = wgDevice
  	outbound.tunDevice = wireDevice
  	return outbound, nil
  }
  // DialContext connects to destination through the WireGuard tunnel's gVisor
  // stack.
  //
  // network may be any TCP or UDP variant recognised by N.NetworkName. Any
  // other value yields N.ErrUnknownNetwork before any DNS lookup is made.
  //
  // A domain destination is resolved with the router's default lookup, and
  // the first address is used; an empty result is reported as an error. The
  // IP protocol and local bind address (addr4/addr6) follow the family of the
  // address actually dialed, so a domain that resolves to IPv4 is dialed over
  // IPv4. On failure a literal nil conn is returned, never a typed nil.
  func (w *WireGuard) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  	networkName := N.NetworkName(network)
  	switch networkName {
  	case N.NetworkTCP:
  		w.logger.InfoContext(ctx, "outbound connection to ", destination)
  	case N.NetworkUDP:
  		w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
  	default:
  		return nil, E.Extend(N.ErrUnknownNetwork, network)
  	}

  	remoteAddr := destination.Addr
  	if destination.IsFqdn() {
  		addrs, err := w.router.LookupDefault(ctx, destination.Fqdn)
  		if err != nil {
  			return nil, err
  		}
  		if len(addrs) == 0 {
  			return nil, E.New("no address resolved for ", destination.Fqdn)
  		}
  		remoteAddr = addrs[0]
  	}
  	// Treat IPv4-mapped IPv6 addresses as IPv4 so the address length agrees
  	// with the protocol chosen below.
  	remoteAddr = remoteAddr.Unmap()

  	addr := tcpip.FullAddress{
  		NIC:  defaultNIC,
  		Addr: tcpip.Address(remoteAddr.AsSlice()),
  		Port: destination.Port,
  	}
  	bind := tcpip.FullAddress{NIC: defaultNIC}
  	var networkProtocol tcpip.NetworkProtocolNumber
  	if remoteAddr.Is4() {
  		networkProtocol = header.IPv4ProtocolNumber
  		bind.Addr = w.tunDevice.addr4
  	} else {
  		networkProtocol = header.IPv6ProtocolNumber
  		bind.Addr = w.tunDevice.addr6
  	}

  	// The results are checked explicitly so that a nil *gonet conn is never
  	// wrapped in a non-nil net.Conn interface.
  	if networkName == N.NetworkTCP {
  		tcpConn, err := gonet.DialTCPWithBind(ctx, w.tunDevice.stack, bind, addr, networkProtocol)
  		if err != nil {
  			return nil, err
  		}
  		return tcpConn, nil
  	}
  	udpConn, err := gonet.DialUDP(w.tunDevice.stack, &bind, &addr, networkProtocol)
  	if err != nil {
  		return nil, err
  	}
  	return udpConn, nil
  }
  200. func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  201. w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
  202. bind := tcpip.FullAddress{
  203. NIC: defaultNIC,
  204. }
  205. var networkProtocol tcpip.NetworkProtocolNumber
  206. if destination.IsIPv4() || w.tunDevice.addr6 == "" {
  207. networkProtocol = header.IPv4ProtocolNumber
  208. bind.Addr = w.tunDevice.addr4
  209. } else {
  210. networkProtocol = header.IPv6ProtocolNumber
  211. bind.Addr = w.tunDevice.addr6
  212. }
  213. return gonet.DialUDP(w.tunDevice.stack, &bind, nil, networkProtocol)
  214. }
  // NewConnection relays an inbound stream connection through this outbound
  // using the package's shared NewEarlyConnection helper.
  func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
  	return NewEarlyConnection(ctx, w, conn, metadata)
  }

  // NewPacketConnection relays an inbound packet connection through this
  // outbound using the package's shared NewPacketConnection helper.
  func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
  	return NewPacketConnection(ctx, w, conn, metadata)
  }
  221. func (w *WireGuard) Start() error {
  222. w.tunDevice.events <- tun.EventUp
  223. return nil
  224. }
  225. func (w *WireGuard) Close() error {
  226. return common.Close(
  227. common.PtrOrNil(w.tunDevice),
  228. common.PtrOrNil(w.device),
  229. common.PtrOrNil(w.conn),
  230. )
  231. }
  232. var _ conn.Bind = (*wireClientBind)(nil)
  233. type wireClientBind WireGuard
  234. func (c *wireClientBind) connect() (*wireConn, error) {
  235. c.connAccess.Lock()
  236. defer c.connAccess.Unlock()
  237. if c.conn != nil {
  238. select {
  239. case <-c.conn.done:
  240. default:
  241. return c.conn, nil
  242. }
  243. }
  244. udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr)
  245. if err != nil {
  246. return nil, &wireError{err}
  247. }
  248. c.conn = &wireConn{
  249. Conn: udpConn,
  250. done: make(chan struct{}),
  251. }
  252. return c.conn, nil
  253. }
  254. func (c *wireClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  255. return []conn.ReceiveFunc{c.receive}, 0, nil
  256. }
  257. func (c *wireClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
  258. udpConn, err := c.connect()
  259. if err != nil {
  260. return
  261. }
  262. n, err = udpConn.Read(b)
  263. if err != nil {
  264. udpConn.Close()
  265. err = &wireError{err}
  266. }
  267. ep = c.endpoint
  268. return
  269. }
  270. func (c *wireClientBind) Close() error {
  271. c.connAccess.Lock()
  272. defer c.connAccess.Unlock()
  273. common.Close(common.PtrOrNil(c.conn))
  274. return nil
  275. }
  276. func (c *wireClientBind) SetMark(mark uint32) error {
  277. return nil
  278. }
  279. func (c *wireClientBind) Send(b []byte, ep conn.Endpoint) error {
  280. udpConn, err := c.connect()
  281. if err != nil {
  282. return err
  283. }
  284. _, err = udpConn.Write(b)
  285. if err != nil {
  286. udpConn.Close()
  287. }
  288. return err
  289. }
  290. func (c *wireClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
  291. return c.endpoint, nil
  292. }
  293. type wireError struct {
  294. cause error
  295. }
  296. func (w *wireError) Error() string {
  297. return w.cause.Error()
  298. }
  299. func (w *wireError) Timeout() bool {
  300. if cause, causeNet := w.cause.(net.Error); causeNet {
  301. return cause.Timeout()
  302. }
  303. return false
  304. }
  305. func (w *wireError) Temporary() bool {
  306. return true
  307. }
  308. func (w *wireError) Unwrap() error {
  309. return w.cause
  310. }
  311. type wireConn struct {
  312. net.Conn
  313. access sync.Mutex
  314. done chan struct{}
  315. }
  316. func (w *wireConn) Close() error {
  317. w.access.Lock()
  318. defer w.access.Unlock()
  319. select {
  320. case <-w.done:
  321. return net.ErrClosed
  322. default:
  323. }
  324. w.Conn.Close()
  325. close(w.done)
  326. return nil
  327. }
  328. var _ tun.Device = (*wireTunDevice)(nil)
  329. const defaultNIC tcpip.NICID = 1
  330. type wireTunDevice struct {
  331. stack *stack.Stack
  332. mtu uint32
  333. events chan tun.Event
  334. outbound chan *stack.PacketBuffer
  335. dispatcher stack.NetworkDispatcher
  336. done chan struct{}
  337. addr4 tcpip.Address
  338. addr6 tcpip.Address
  339. }
  340. func newWireDevice(localAddresses []tcpip.AddressWithPrefix, mtu uint32) (*wireTunDevice, error) {
  341. ipStack := stack.New(stack.Options{
  342. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  343. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  344. HandleLocal: true,
  345. })
  346. tunDevice := &wireTunDevice{
  347. stack: ipStack,
  348. mtu: mtu,
  349. events: make(chan tun.Event, 4),
  350. outbound: make(chan *stack.PacketBuffer, 256),
  351. done: make(chan struct{}),
  352. }
  353. err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
  354. if err != nil {
  355. return nil, E.New(err.String())
  356. }
  357. for _, addr := range localAddresses {
  358. var protoAddr tcpip.ProtocolAddress
  359. if len(addr.Address) == net.IPv4len {
  360. tunDevice.addr4 = addr.Address
  361. protoAddr = tcpip.ProtocolAddress{
  362. Protocol: ipv4.ProtocolNumber,
  363. AddressWithPrefix: addr,
  364. }
  365. } else {
  366. tunDevice.addr6 = addr.Address
  367. protoAddr = tcpip.ProtocolAddress{
  368. Protocol: ipv6.ProtocolNumber,
  369. AddressWithPrefix: addr,
  370. }
  371. }
  372. err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
  373. if err != nil {
  374. return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
  375. }
  376. }
  377. sOpt := tcpip.TCPSACKEnabled(true)
  378. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
  379. cOpt := tcpip.CongestionControlOption("cubic")
  380. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
  381. ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
  382. ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
  383. return tunDevice, nil
  384. }
  // File returns nil: the device is not backed by an operating-system file.
  func (w *wireTunDevice) File() *os.File { return nil }
  // Read blocks until the stack queues an outgoing packet and copies it into
  // p[offset:]. A packet longer than that space is truncated. The packet's
  // reference, taken when it was queued, is released before returning. Once
  // the queue has been closed, Read returns os.ErrClosed.
  func (w *wireTunDevice) Read(p []byte, offset int) (n int, err error) {
  	pkt, open := <-w.outbound
  	if !open {
  		return 0, os.ErrClosed
  	}
  	defer pkt.DecRef()
  	dst := p[offset:]
  	for _, chunk := range pkt.AsSlices() {
  		n += copy(dst[n:], chunk)
  	}
  	return n, nil
  }
  400. func (w *wireTunDevice) Write(p []byte, offset int) (n int, err error) {
  401. p = p[offset:]
  402. if len(p) == 0 {
  403. return
  404. }
  405. var networkProtocol tcpip.NetworkProtocolNumber
  406. switch header.IPVersion(p) {
  407. case header.IPv4Version:
  408. networkProtocol = header.IPv4ProtocolNumber
  409. case header.IPv6Version:
  410. networkProtocol = header.IPv6ProtocolNumber
  411. }
  412. packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
  413. Payload: bufferv2.MakeWithData(p),
  414. })
  415. defer packetBuffer.DecRef()
  416. w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
  417. n = len(p)
  418. return
  419. }
  420. func (w *wireTunDevice) Flush() error {
  421. return nil
  422. }
  423. func (w *wireTunDevice) MTU() (int, error) {
  424. return int(w.mtu), nil
  425. }
  426. func (w *wireTunDevice) Name() (string, error) {
  427. return "sing-box", nil
  428. }
  429. func (w *wireTunDevice) Events() chan tun.Event {
  430. return w.events
  431. }
  // Close shuts the device down in this order:
  //   1. mark it done;
  //   2. close the gVisor stack;
  //   3. abort endpoints still pending cleanup;
  //   4. wait for the stack;
  //   5. close the outbound queue, which makes a blocked Read return os.ErrClosed.
  //
  // A later call returns os.ErrClosed. NOTE(review): the done check is not
  // synchronized, so concurrent calls are not safe.
  func (w *wireTunDevice) Close() error {
  	select {
  	case <-w.done:
  		return os.ErrClosed
  	default:
  	}
  	close(w.done)
  	w.stack.Close()
  	pending := w.stack.CleanupEndpoints()
  	for i := range pending {
  		pending[i].Abort()
  	}
  	w.stack.Wait()
  	close(w.outbound)
  	return nil
  }
  447. var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
  448. type wireEndpoint wireTunDevice
  449. func (ep *wireEndpoint) MTU() uint32 {
  450. return ep.mtu
  451. }
  452. func (ep *wireEndpoint) MaxHeaderLength() uint16 {
  453. return 0
  454. }
  455. func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
  456. return ""
  457. }
  458. func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  459. return stack.CapabilityNone
  460. }
  461. func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  462. ep.dispatcher = dispatcher
  463. }
  464. func (ep *wireEndpoint) IsAttached() bool {
  465. return ep.dispatcher != nil
  466. }
  467. func (ep *wireEndpoint) Wait() {
  468. }
  469. func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
  470. return header.ARPHardwareNone
  471. }
  472. func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  473. }
  474. func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  475. for _, packetBuffer := range list.AsSlice() {
  476. packetBuffer.IncRef()
  477. ep.outbound <- packetBuffer
  478. }
  479. return list.Len(), nil
  480. }