package wireguard

import (
	"context"
	"errors"
	"fmt"
	gonet "net"
	"net/netip"
	"strconv"
	"sync"

	"golang.zx2c4.com/wireguard/conn"

	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/features/dns"
	"github.com/xtls/xray-core/transport/internet"
)

// netReadInfo is a single read request handed from a conn.ReceiveFunc to the
// goroutine that produces packets for the bind.
type netReadInfo struct {
	// waiter is released by the producer once the result fields are set.
	waiter sync.WaitGroup
	// buff is the caller-owned buffer the packet is copied into.
	buff []byte
	// bytes, endpoint and err hold the read result.
	bytes    int
	endpoint conn.Endpoint
	err      error
}

// netBind holds the state shared by the client and server binds.
type netBind struct {
	dns       dns.Client
	dnsOption dns.IPOption
	// workers is the number of ReceiveFuncs returned by Open (at least 1).
	workers int
	// readQueue carries read requests from ReceiveFuncs to the packet
	// producer. Open creates it; Close closes it.
	readQueue chan *netReadInfo
}

// SetMark implements conn.Bind. It is a no-op for this bind.
func (bind *netBind) SetMark(mark uint32) error {
	return nil
}

// ParseEndpoint implements conn.Bind. It parses "host:port" into a UDP
// endpoint, resolving a domain host through the configured DNS client and
// using the first returned IP.
func (bind *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
	host, portStr, err := net.SplitHostPort(s)
	if err != nil {
		return nil, err
	}
	// bitSize 16 rejects ports above 65535 instead of letting them wrap.
	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return nil, err
	}

	addr := net.ParseAddress(host)
	if addr.Family() == net.AddressFamilyDomain {
		ips, _, err := bind.dns.LookupIP(addr.Domain(), bind.dnsOption)
		if err != nil {
			return nil, err
		}
		if len(ips) == 0 {
			return nil, dns.ErrEmptyResponse
		}
		addr = net.IPAddress(ips[0])
	}

	return &netEndpoint{
		dst: net.Destination{
			Address: addr,
			Port:    net.Port(port),
			Network: net.Network_UDP,
		},
	}, nil
}

// BatchSize implements conn.Bind. Every ReceiveFunc call and every Send
// buffer carries a single packet.
func (bind *netBind) BatchSize() int {
	return 1
}

// Open implements conn.Bind. It creates a fresh read queue and returns
// bind.workers (at least one) identical ReceiveFuncs. Each call hands its
// buffer to the packet producer through the queue and blocks until the
// producer has filled it. Once Close has closed the queue, the funcs fail with
// an error wrapping net.ErrClosed, as the conn.Bind contract requires.
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
	// The funcs capture this queue, so funcs from an earlier Open keep failing
	// on their own closed queue instead of picking up a newer one.
	readQueue := make(chan *netReadInfo)
	bind.readQueue = readQueue

	receive := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
		// Sending on a queue that Close has closed panics; report that as a
		// closed bind so wireguard-go stops this receive routine.
		defer func() {
			if p := recover(); p != nil {
				n = 0
				err = fmt.Errorf("channel closed: %w", gonet.ErrClosed)
			}
		}()

		req := &netReadInfo{buff: bufs[0]}
		req.waiter.Add(1)
		readQueue <- req
		// The result fields are only valid after the producer calls Done.
		req.waiter.Wait()
		sizes[0], eps[0] = req.bytes, req.endpoint
		return 1, req.err
	}

	workers := bind.workers
	if workers <= 0 {
		workers = 1
	}
	fns := make([]conn.ReceiveFunc, workers)
	for i := range fns {
		fns[i] = receive
	}
	return fns, uport, nil
}

// Close implements conn.Bind. It closes the current read queue, which makes
// blocked and future calls of the funcs returned by Open fail.
// NOTE(review): a second Close without an Open in between closes an
// already-closed channel and panics.
func (bind *netBind) Close() error {
	if bind.readQueue != nil {
		close(bind.readQueue)
	}
	return nil
}

// netBindClient is the dialing side of the bind. It dials one connection per
// peer endpoint on the first Send to that endpoint, and funnels everything
// those connections receive into the ReceiveFuncs returned by Open.
type netBindClient struct {
	netBind

	ctx    context.Context
	dialer internet.Dialer
	// reserved, when exactly 3 bytes long, is written into bytes 1..3 of
	// every outgoing packet.
	reserved []byte

	// connMutex guards conns, dataChan, closeChan and the conn field of every
	// endpoint dialed by this bind. The three fields below are created lazily
	// by the first connectTo after construction or after Close, and are reset
	// to nil by Close.
	connMutex sync.RWMutex
	conns     map[*netEndpoint]net.Conn
	dataChan  chan *receivedData
	closeChan chan struct{}
}

const (
	// Buffer size for dataChan - allows some buffering of received packets
	// while dispatcher matches them with read requests
	dataChannelBufferSize = 100
)

// receivedData is one read result from a peer connection, queued for the
// dispatcher.
type receivedData struct {
	data     []byte // exactly the bytes that were read
	n        int    // len(data)
	endpoint *netEndpoint
	err      error
}

// connectTo dials endpoint and registers the connection. On the first
// connection of a generation it also creates the dispatch channels and starts
// unifiedReader. If a concurrent Send already connected the endpoint, the new
// connection is closed and the existing one is kept.
//
// NOTE(review): this reads bind.readQueue, which Open/Close write without
// connMutex; this assumes wireguard-go serializes Send against Open/Close
// (it holds its device net lock around both) — confirm if used elsewhere.
func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
	c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
	if err != nil {
		return err
	}

	bind.connMutex.Lock()
	if endpoint.conn != nil {
		// Lost the race against another Send for the same endpoint.
		bind.connMutex.Unlock()
		c.Close() // best effort: this connection was never used
		return nil
	}
	if bind.conns == nil {
		bind.conns = make(map[*netEndpoint]net.Conn)
		bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
		bind.closeChan = make(chan struct{})
		go bind.unifiedReader(bind.dataChan, bind.readQueue, bind.closeChan)
	}
	bind.conns[endpoint] = c
	endpoint.conn = c
	dataChan, closeChan := bind.dataChan, bind.closeChan
	bind.connMutex.Unlock()

	go bind.readLoop(c, endpoint, dataChan, closeChan)
	return nil
}

// readLoop forwards everything read from c to the dispatcher on dataChan. It
// stops after the first read error (which is forwarded too) or when closeChan,
// the generation's stop signal, is closed.
func (bind *netBindClient) readLoop(c net.Conn, endpoint *netEndpoint, dataChan chan<- *receivedData, closeChan <-chan struct{}) {
	// NOTE(review): reads longer than this cannot be returned in one piece;
	// how the connection handles larger datagrams depends on its type.
	const maxPacketSize = 1500

	// One read buffer per connection; each packet is copied out at its exact
	// size so the buffer can be reused for the next read.
	buf := make([]byte, maxPacketSize)
	for {
		n, err := c.Read(buf)
		// Defensive clamp against readers that violate 0 <= n <= len(buf).
		if n < 0 {
			n = 0
		} else if n > len(buf) {
			n = len(buf)
		}
		packet := make([]byte, n)
		copy(packet, buf[:n])

		if err != nil {
			bind.connMutex.Lock()
			// If Close has already reset this generation, it also took care
			// of the endpoint; otherwise forget the dead connection so the
			// next Send redials.
			if bind.closeChan == closeChan {
				delete(bind.conns, endpoint)
				endpoint.conn = nil
			}
			bind.connMutex.Unlock()
			c.Close() // best effort: the connection already failed
		}

		select {
		case dataChan <- &receivedData{data: packet, n: n, endpoint: endpoint, err: err}:
		case <-closeChan:
			return
		}
		if err != nil {
			return
		}
	}
}

// unifiedReader pairs each received packet with the next pending read request
// from readQueue. It exits when closeChan is closed or readQueue is closed.
func (bind *netBindClient) unifiedReader(dataChan <-chan *receivedData, readQueue <-chan *netReadInfo, closeChan <-chan struct{}) {
	for {
		var data *receivedData
		select {
		case data = <-dataChan:
		case <-closeChan:
			return
		}

		select {
		case req, ok := <-readQueue:
			if !ok {
				// The bind was closed; a closed queue yields nil requests.
				return
			}
			n := copy(req.buff, data.data)
			// Zero the three reserved header bytes that peers may set.
			if n > 3 {
				req.buff[1] = 0
				req.buff[2] = 0
				req.buff[3] = 0
			}
			req.bytes = n
			req.endpoint = data.endpoint
			req.err = data.err
			req.waiter.Done()
		case <-closeChan:
			return
		}
	}
}

// Close implements conn.Bind.Close for netBindClient. It stops the current
// dispatch generation, closes every dialed connection (which unblocks their
// read loops), clears the endpoints so the next Send redials, and then closes
// the read queue. Only the final netBind.Close step can panic on a repeated
// call; the client-specific teardown is safe to repeat.
func (bind *netBindClient) Close() error {
	bind.connMutex.Lock()
	if bind.closeChan != nil {
		close(bind.closeChan)
	}
	stale := make([]net.Conn, 0, len(bind.conns))
	for endpoint, c := range bind.conns {
		endpoint.conn = nil
		stale = append(stale, c)
	}
	bind.conns = nil
	bind.dataChan = nil
	bind.closeChan = nil
	bind.connMutex.Unlock()

	for _, c := range stale {
		c.Close() // best effort: we are shutting down
	}
	return bind.netBind.Close()
}

// endpointConn returns the connection currently registered for endpoint, or
// nil if there is none.
func (bind *netBindClient) endpointConn(endpoint *netEndpoint) net.Conn {
	bind.connMutex.RLock()
	defer bind.connMutex.RUnlock()
	return endpoint.conn
}

// Send implements conn.Bind. It dials the endpoint on first use, stamps the
// configured reserved bytes into each packet (mutating the caller's buffers)
// and writes the packets in order, stopping at the first write error.
func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
	nend, ok := endpoint.(*netEndpoint)
	if !ok {
		return conn.ErrWrongEndpointType
	}

	c := bind.endpointConn(nend)
	if c == nil {
		if err := bind.connectTo(nend); err != nil {
			return err
		}
		// The read loop may already have dropped a connection that failed
		// immediately.
		if c = bind.endpointConn(nend); c == nil {
			return errors.New("connection closed")
		}
	}

	for _, packet := range buff {
		if len(packet) > 3 && len(bind.reserved) == 3 {
			copy(packet[1:], bind.reserved)
		}
		if _, err := c.Write(packet); err != nil {
			return err
		}
	}
	return nil
}

// netBindServer is the accepting side of the bind; endpoint connections are
// attached outside this file.
type netBindServer struct {
	netBind
}

// Send implements conn.Bind. It writes the packets to the endpoint's existing
// connection and fails if none has been attached yet.
func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
	nend, ok := endpoint.(*netEndpoint)
	if !ok {
		return conn.ErrWrongEndpointType
	}

	// Load once so the nil check and the writes see the same connection.
	c := nend.conn
	if c == nil {
		return errors.New("connection not open yet")
	}

	for _, packet := range buff {
		if _, err := c.Write(packet); err != nil {
			return err
		}
	}
	return nil
}

// netEndpoint is a conn.Endpoint backed by a destination address and, once
// connected, the connection used to reach it.
type netEndpoint struct {
	dst  net.Destination
	conn net.Conn
}

// ClearSrc implements conn.Endpoint; there is no source state to clear.
func (netEndpoint) ClearSrc() {}

// DstIP implements conn.Endpoint; it always returns the zero netip.Addr.
func (e netEndpoint) DstIP() netip.Addr {
	return netip.Addr{}
}

// SrcIP implements conn.Endpoint; it always returns the zero netip.Addr.
func (e netEndpoint) SrcIP() netip.Addr {
	return netip.Addr{}
}

// DstToBytes implements conn.Endpoint. It returns the destination IP (4 bytes
// for IPv4, 16 otherwise) followed by the port in little-endian order.
func (e netEndpoint) DstToBytes() []byte {
	ip := e.dst.Address.IP()
	if e.dst.Address.Family().IsIPv4() {
		ip = ip.To4()
	} else {
		ip = ip.To16()
	}
	// Build into a fresh slice so appending the port never writes into the
	// backing array of the address's IP slice.
	out := make([]byte, 0, len(ip)+2)
	out = append(out, ip...)
	return append(out, byte(e.dst.Port), byte(e.dst.Port>>8))
}

// DstToString implements conn.Endpoint.
func (e netEndpoint) DstToString() string {
	return e.dst.NetAddr()
}

// SrcToString implements conn.Endpoint; there is no source address.
func (e netEndpoint) SrcToString() string {
	return ""
}

// toNetIpAddr converts an IP-family address to a netip.Addr. IPv4 addresses
// become plain 4-byte addrs; anything else is treated as 16-byte IPv6.
// NOTE(review): presumably only valid for IP (not domain) addresses — confirm
// at call sites.
func toNetIpAddr(addr net.Address) netip.Addr {
	if addr.Family().IsIPv4() {
		ip := addr.IP().To4()
		return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
	}
	var arr [16]byte
	copy(arr[:], addr.IP())
	return netip.AddrFrom16(arr)
}