| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- package wireguard
- import (
- "context"
- "fmt"
- "net"
- "net/netip"
- "runtime"
- "strconv"
- "strings"
- "sync"
- "time"
- "github.com/xtls/xray-core/common/errors"
- "github.com/xtls/xray-core/common/log"
- xnet "github.com/xtls/xray-core/common/net"
- "github.com/xtls/xray-core/proxy/wireguard/gvisortun"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun"
- )
- type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error)
- type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn)
- type Tunnel interface {
- BuildDevice(ipc string, bind conn.Bind) error
- DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
- DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error)
- Close() error
- }
- type tunnel struct {
- tun tun.Device
- device *device.Device
- rw sync.Mutex
- }
- func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) {
- t.rw.Lock()
- defer t.rw.Unlock()
- if t.device != nil {
- return errors.New("device is already initialized")
- }
- logger := &device.Logger{
- Verbosef: func(format string, args ...any) {
- log.Record(&log.GeneralMessage{
- Severity: log.Severity_Debug,
- Content: fmt.Sprintf(format, args...),
- })
- },
- Errorf: func(format string, args ...any) {
- log.Record(&log.GeneralMessage{
- Severity: log.Severity_Error,
- Content: fmt.Sprintf(format, args...),
- })
- },
- }
- t.device = device.NewDevice(t.tun, bind, logger)
- if err = t.device.IpcSet(ipc); err != nil {
- return err
- }
- if err = t.device.Up(); err != nil {
- return err
- }
- return nil
- }
- func (t *tunnel) Close() (err error) {
- t.rw.Lock()
- defer t.rw.Unlock()
- if t.device == nil {
- return nil
- }
- t.device.Close()
- t.device = nil
- err = t.tun.Close()
- t.tun = nil
- return nil
- }
// CalculateInterfaceName picks a name for a new TUN interface.
//
// The base name is "utun" on darwin (name is ignored there). Elsewhere it is
// name, or "tun" when name is empty. The result is the base followed by one more
// than the largest numeric suffix among existing interfaces that share the base,
// so "tun0" when none exist. If the interface list cannot be read, the bare base
// name is returned.
func CalculateInterfaceName(name string) (tunName string) {
	switch {
	case runtime.GOOS == "darwin":
		tunName = "utun"
	case name != "":
		tunName = name
	default:
		tunName = "tun"
	}

	interfaces, err := net.Interfaces()
	if err != nil {
		// Without the interface list no free index can be chosen; fall back to the base.
		return tunName
	}

	names := make([]string, 0, len(interfaces))
	for _, iface := range interfaces {
		names = append(names, iface.Name)
	}
	return tunName + strconv.Itoa(nextInterfaceIndex(tunName, names))
}

// nextInterfaceIndex returns one past the largest unsigned 16-bit decimal
// suffix found on any name in names that starts with prefix, or 0 if none do.
// The maximum is taken over all matches. Enumeration order is not
// significant, so "tun1" listed before "tun0" still yields 2.
func nextInterfaceIndex(prefix string, names []string) int {
	next := 0
	for _, ifName := range names {
		if !strings.HasPrefix(ifName, prefix) {
			continue
		}
		// ParseUint rejects signs, so names like "tun-1" are not counted.
		idx, err := strconv.ParseUint(ifName[len(prefix):], 10, 16)
		if err != nil {
			continue
		}
		if candidate := int(idx) + 1; candidate > next {
			next = candidate
		}
	}
	return next
}
- var _ Tunnel = (*gvisorNet)(nil)
- type gvisorNet struct {
- tunnel
- net *gvisortun.Net
- }
- func (g *gvisorNet) Close() error {
- return g.tunnel.Close()
- }
- func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
- net.Conn, error,
- ) {
- return g.net.DialContextTCPAddrPort(ctx, addr)
- }
- func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
- return g.net.DialUDPAddrPort(laddr, raddr)
- }
- func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
- out := &gvisorNet{}
- tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
- if err != nil {
- return nil, err
- }
- if handler != nil {
- // handler is only used for promiscuous mode
- // capture all packets and send to handler
- tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
- go func(r *tcp.ForwarderRequest) {
- var (
- wq waiter.Queue
- id = r.ID()
- )
- // Perform a TCP three-way handshake.
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- errors.LogError(context.Background(), err.String())
- r.Complete(true)
- return
- }
- r.Complete(false)
- defer ep.Close()
- // enable tcp keep-alive to prevent hanging connections
- ep.SocketOptions().SetKeepAlive(true)
- // local address is actually destination
- handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
- }(r)
- })
- stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
- udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
- go func(r *udp.ForwarderRequest) {
- var (
- wq waiter.Queue
- id = r.ID()
- )
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- errors.LogError(context.Background(), err.String())
- return
- }
- defer ep.Close()
- // prevents hanging connections and ensure timely release
- ep.SocketOptions().SetLinger(tcpip.LingerOption{
- Enabled: true,
- Timeout: 15 * time.Second,
- })
- handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep))
- }(r)
- })
- stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
- }
- out.tun, out.net = tun, n
- return out, nil
- }
|