tun_linux.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. //go:build linux && !android
  2. package wireguard
  3. import (
  4. "context"
  5. goerrors "errors"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "sync"
  11. "golang.org/x/sys/unix"
  12. "github.com/sagernet/sing/common/control"
  13. "github.com/vishvananda/netlink"
  14. "github.com/xtls/xray-core/common/errors"
  15. wgtun "golang.zx2c4.com/wireguard/tun"
  16. )
  17. type deviceNet struct {
  18. tunnel
  19. dialer net.Dialer
  20. handle *netlink.Handle
  21. linkAddrs []netlink.Addr
  22. routes []*netlink.Route
  23. rules []*netlink.Rule
  24. }
  // tableIndex is the next routing-table number allocateIPv6TableIndex will
// hand out; within this file it is only ever incremented.
var tableIndex = 10230

// mu serializes access to tableIndex.
var mu sync.Mutex
  29. func allocateIPv6TableIndex() int {
  30. mu.Lock()
  31. defer mu.Unlock()
  32. if tableIndex > 10230 {
  33. errors.LogInfo(context.Background(), "allocate new ipv6 table index: ", tableIndex)
  34. }
  35. currentIndex := tableIndex
  36. tableIndex++
  37. return currentIndex
  38. }
  39. func newDeviceNet(interfaceName string) *deviceNet {
  40. var dialer net.Dialer
  41. bindControl := control.BindToInterface(control.NewDefaultInterfaceFinder(), interfaceName, -1)
  42. dialer.Control = control.Append(dialer.Control, bindControl)
  43. return &deviceNet{dialer: dialer}
  44. }
  45. func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
  46. net.Conn, error,
  47. ) {
  48. return d.dialer.DialContext(ctx, "tcp", addr.String())
  49. }
  50. func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
  51. dialer := d.dialer
  52. dialer.LocalAddr = &net.UDPAddr{IP: laddr.Addr().AsSlice(), Port: int(laddr.Port())}
  53. return dialer.DialContext(context.Background(), "udp", raddr.String())
  54. }
  55. func (d *deviceNet) Close() (err error) {
  56. var errs []error
  57. for _, rule := range d.rules {
  58. if err = d.handle.RuleDel(rule); err != nil {
  59. errs = append(errs, fmt.Errorf("failed to delete rule: %w", err))
  60. }
  61. }
  62. for _, route := range d.routes {
  63. if err = d.handle.RouteDel(route); err != nil {
  64. errs = append(errs, fmt.Errorf("failed to delete route: %w", err))
  65. }
  66. }
  67. if err = d.tunnel.Close(); err != nil {
  68. errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err))
  69. }
  70. if d.handle != nil {
  71. d.handle.Close()
  72. d.handle = nil
  73. }
  74. if len(errs) == 0 {
  75. return nil
  76. }
  77. return goerrors.Join(errs...)
  78. }
  79. func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
  80. if handler != nil {
  81. return nil, errors.New("TODO: support promiscuous mode")
  82. }
  83. var v4, v6 *netip.Addr
  84. for _, prefixes := range localAddresses {
  85. if v4 == nil && prefixes.Is4() {
  86. x := prefixes
  87. v4 = &x
  88. }
  89. if v6 == nil && prefixes.Is6() {
  90. x := prefixes
  91. v6 = &x
  92. }
  93. }
  94. writeSysctlZero := func(path string) error {
  95. _, err := os.Stat(path)
  96. if os.IsNotExist(err) {
  97. return nil
  98. }
  99. if err != nil {
  100. return err
  101. }
  102. return os.WriteFile(path, []byte("0"), 0o644)
  103. }
  104. // system configs.
  105. if v4 != nil {
  106. if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil {
  107. return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err)
  108. }
  109. }
  110. if v6 != nil {
  111. if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil {
  112. return nil, fmt.Errorf("failed to enable ipv6: %w", err)
  113. }
  114. if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil {
  115. return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err)
  116. }
  117. }
  118. n := CalculateInterfaceName("wg")
  119. wgt, err := wgtun.CreateTUN(n, mtu)
  120. if err != nil {
  121. return nil, err
  122. }
  123. defer func() {
  124. if err != nil {
  125. _ = wgt.Close()
  126. }
  127. }()
  128. // disable linux rp_filter for tunnel device to avoid packet drop.
  129. // the operation require root privilege on container require '--privileged' flag.
  130. if v4 != nil {
  131. if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil {
  132. return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err)
  133. }
  134. }
  135. if v6 != nil {
  136. if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil {
  137. return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err)
  138. }
  139. }
  140. ipv6TableIndex := allocateIPv6TableIndex()
  141. if v6 != nil {
  142. r := &netlink.Route{Table: ipv6TableIndex}
  143. for {
  144. routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_V6, r, netlink.RT_FILTER_TABLE)
  145. if len(routeList) == 0 || fErr != nil {
  146. break
  147. }
  148. ipv6TableIndex--
  149. if ipv6TableIndex < 0 {
  150. return nil, fmt.Errorf("failed to find available ipv6 table index")
  151. }
  152. }
  153. }
  154. out := newDeviceNet(n)
  155. out.handle, err = netlink.NewHandle()
  156. if err != nil {
  157. return nil, err
  158. }
  159. defer func() {
  160. if err != nil {
  161. _ = out.Close()
  162. }
  163. }()
  164. l, err := netlink.LinkByName(n)
  165. if err != nil {
  166. return nil, err
  167. }
  168. if v4 != nil {
  169. addr := netlink.Addr{
  170. IPNet: &net.IPNet{
  171. IP: v4.AsSlice(),
  172. Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()),
  173. },
  174. }
  175. out.linkAddrs = append(out.linkAddrs, addr)
  176. }
  177. if v6 != nil {
  178. addr := netlink.Addr{
  179. IPNet: &net.IPNet{
  180. IP: v6.AsSlice(),
  181. Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()),
  182. },
  183. }
  184. out.linkAddrs = append(out.linkAddrs, addr)
  185. rt := &netlink.Route{
  186. LinkIndex: l.Attrs().Index,
  187. Dst: &net.IPNet{
  188. IP: net.IPv6zero,
  189. Mask: net.CIDRMask(0, 128),
  190. },
  191. Table: ipv6TableIndex,
  192. }
  193. out.routes = append(out.routes, rt)
  194. r := netlink.NewRule()
  195. r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet
  196. out.rules = append(out.rules, r)
  197. r = netlink.NewRule()
  198. r.Table, r.Family, r.OifName = ipv6TableIndex, unix.AF_INET6, n
  199. out.rules = append(out.rules, r)
  200. }
  201. for _, addr := range out.linkAddrs {
  202. if err = out.handle.AddrAdd(l, &addr); err != nil {
  203. return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err)
  204. }
  205. }
  206. if err = out.handle.LinkSetMTU(l, mtu); err != nil {
  207. return nil, err
  208. }
  209. if err = out.handle.LinkSetUp(l); err != nil {
  210. return nil, err
  211. }
  212. for _, route := range out.routes {
  213. if err = out.handle.RouteAdd(route); err != nil {
  214. return nil, fmt.Errorf("failed to add route %s: %w", route, err)
  215. }
  216. }
  217. for _, rule := range out.rules {
  218. if err = out.handle.RuleAdd(rule); err != nil {
  219. return nil, fmt.Errorf("failed to add rule %s: %w", rule, err)
  220. }
  221. }
  222. out.tun = wgt
  223. return out, nil
  224. }
  225. func KernelTunSupported() (bool, error) {
  226. var hdr unix.CapUserHeader
  227. hdr.Version = unix.LINUX_CAPABILITY_VERSION_3
  228. hdr.Pid = 0 // 0 means current process
  229. var data unix.CapUserData
  230. if err := unix.Capget(&hdr, &data); err != nil {
  231. return false, fmt.Errorf("failed to get capabilities: %v", err)
  232. }
  233. return (data.Effective & (1 << unix.CAP_NET_ADMIN)) != 0, nil
  234. }