123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- //go:build linux && !android
- package wireguard
- import (
- "context"
- goerrors "errors"
- "fmt"
- "net"
- "net/netip"
- "os"
- "sync"
- "golang.org/x/sys/unix"
- "github.com/sagernet/sing/common/control"
- "github.com/vishvananda/netlink"
- "github.com/xtls/xray-core/common/errors"
- wgtun "golang.zx2c4.com/wireguard/tun"
- )
- type deviceNet struct {
- tunnel
- dialer net.Dialer
- handle *netlink.Handle
- linkAddrs []netlink.Addr
- routes []*netlink.Route
- rules []*netlink.Rule
- }
- var (
- tableIndex int = 10230
- mu sync.Mutex
- )
- func allocateIPv6TableIndex() int {
- mu.Lock()
- defer mu.Unlock()
- if tableIndex > 10230 {
- errors.LogInfo(context.Background(), "allocate new ipv6 table index: ", tableIndex)
- }
- currentIndex := tableIndex
- tableIndex++
- return currentIndex
- }
- func newDeviceNet(interfaceName string) *deviceNet {
- var dialer net.Dialer
- bindControl := control.BindToInterface(control.NewDefaultInterfaceFinder(), interfaceName, -1)
- dialer.Control = control.Append(dialer.Control, bindControl)
- return &deviceNet{dialer: dialer}
- }
- func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
- net.Conn, error,
- ) {
- return d.dialer.DialContext(ctx, "tcp", addr.String())
- }
- func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
- dialer := d.dialer
- dialer.LocalAddr = &net.UDPAddr{IP: laddr.Addr().AsSlice(), Port: int(laddr.Port())}
- return dialer.DialContext(context.Background(), "udp", raddr.String())
- }
- func (d *deviceNet) Close() (err error) {
- var errs []error
- for _, rule := range d.rules {
- if err = d.handle.RuleDel(rule); err != nil {
- errs = append(errs, fmt.Errorf("failed to delete rule: %w", err))
- }
- }
- for _, route := range d.routes {
- if err = d.handle.RouteDel(route); err != nil {
- errs = append(errs, fmt.Errorf("failed to delete route: %w", err))
- }
- }
- if err = d.tunnel.Close(); err != nil {
- errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err))
- }
- if d.handle != nil {
- d.handle.Close()
- d.handle = nil
- }
- if len(errs) == 0 {
- return nil
- }
- return goerrors.Join(errs...)
- }
- func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
- if handler != nil {
- return nil, errors.New("TODO: support promiscuous mode")
- }
- var v4, v6 *netip.Addr
- for _, prefixes := range localAddresses {
- if v4 == nil && prefixes.Is4() {
- x := prefixes
- v4 = &x
- }
- if v6 == nil && prefixes.Is6() {
- x := prefixes
- v6 = &x
- }
- }
- writeSysctlZero := func(path string) error {
- _, err := os.Stat(path)
- if os.IsNotExist(err) {
- return nil
- }
- if err != nil {
- return err
- }
- return os.WriteFile(path, []byte("0"), 0o644)
- }
- // system configs.
- if v4 != nil {
- if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil {
- return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err)
- }
- }
- if v6 != nil {
- if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil {
- return nil, fmt.Errorf("failed to enable ipv6: %w", err)
- }
- if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil {
- return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err)
- }
- }
- n := CalculateInterfaceName("wg")
- wgt, err := wgtun.CreateTUN(n, mtu)
- if err != nil {
- return nil, err
- }
- defer func() {
- if err != nil {
- _ = wgt.Close()
- }
- }()
- // disable linux rp_filter for tunnel device to avoid packet drop.
- // the operation require root privilege on container require '--privileged' flag.
- if v4 != nil {
- if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil {
- return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err)
- }
- }
- if v6 != nil {
- if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil {
- return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err)
- }
- }
- ipv6TableIndex := allocateIPv6TableIndex()
- if v6 != nil {
- r := &netlink.Route{Table: ipv6TableIndex}
- for {
- routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_V6, r, netlink.RT_FILTER_TABLE)
- if len(routeList) == 0 || fErr != nil {
- break
- }
- ipv6TableIndex--
- if ipv6TableIndex < 0 {
- return nil, fmt.Errorf("failed to find available ipv6 table index")
- }
- }
- }
- out := newDeviceNet(n)
- out.handle, err = netlink.NewHandle()
- if err != nil {
- return nil, err
- }
- defer func() {
- if err != nil {
- _ = out.Close()
- }
- }()
- l, err := netlink.LinkByName(n)
- if err != nil {
- return nil, err
- }
- if v4 != nil {
- addr := netlink.Addr{
- IPNet: &net.IPNet{
- IP: v4.AsSlice(),
- Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()),
- },
- }
- out.linkAddrs = append(out.linkAddrs, addr)
- }
- if v6 != nil {
- addr := netlink.Addr{
- IPNet: &net.IPNet{
- IP: v6.AsSlice(),
- Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()),
- },
- }
- out.linkAddrs = append(out.linkAddrs, addr)
- rt := &netlink.Route{
- LinkIndex: l.Attrs().Index,
- Dst: &net.IPNet{
- IP: net.IPv6zero,
- Mask: net.CIDRMask(0, 128),
- },
- Table: ipv6TableIndex,
- }
- out.routes = append(out.routes, rt)
- r := netlink.NewRule()
- r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet
- out.rules = append(out.rules, r)
- r = netlink.NewRule()
- r.Table, r.Family, r.OifName = ipv6TableIndex, unix.AF_INET6, n
- out.rules = append(out.rules, r)
- }
- for _, addr := range out.linkAddrs {
- if err = out.handle.AddrAdd(l, &addr); err != nil {
- return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err)
- }
- }
- if err = out.handle.LinkSetMTU(l, mtu); err != nil {
- return nil, err
- }
- if err = out.handle.LinkSetUp(l); err != nil {
- return nil, err
- }
- for _, route := range out.routes {
- if err = out.handle.RouteAdd(route); err != nil {
- return nil, fmt.Errorf("failed to add route %s: %w", route, err)
- }
- }
- for _, rule := range out.rules {
- if err = out.handle.RuleAdd(rule); err != nil {
- return nil, fmt.Errorf("failed to add rule %s: %w", rule, err)
- }
- }
- out.tun = wgt
- return out, nil
- }
- func KernelTunSupported() (bool, error) {
- var hdr unix.CapUserHeader
- hdr.Version = unix.LINUX_CAPABILITY_VERSION_3
- hdr.Pid = 0 // 0 means current process
- var data unix.CapUserData
- if err := unix.Capget(&hdr, &data); err != nil {
- return false, fmt.Errorf("failed to get capabilities: %v", err)
- }
- return (data.Effective & (1 << unix.CAP_NET_ADMIN)) != 0, nil
- }
|