| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 | package wireguardimport (	"context"	"errors"	"io"	"net"	"net/netip"	"strconv"	"sync"	"golang.zx2c4.com/wireguard/conn"	xnet "github.com/xtls/xray-core/common/net"	"github.com/xtls/xray-core/features/dns"	"github.com/xtls/xray-core/transport/internet")type netReadInfo struct {	// status	waiter sync.WaitGroup	// param	buff []byte	// result	bytes    int	endpoint conn.Endpoint	err      error}// reduce duplicated codetype netBind struct {	dns       dns.Client	dnsOption dns.IPOption	workers   int	readQueue chan *netReadInfo}// SetMark implements conn.Bindfunc (bind *netBind) SetMark(mark uint32) error {	return nil}// ParseEndpoint implements conn.Bindfunc (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {	ipStr, port, err := net.SplitHostPort(s)	if err != nil {		return nil, err	}	portNum, err := strconv.Atoi(port)	if err != nil {		return nil, err	}	addr := xnet.ParseAddress(ipStr)	if addr.Family() == xnet.AddressFamilyDomain {		ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)		if err != nil {			return nil, err		} else if len(ips) == 0 {			return nil, dns.ErrEmptyResponse		}		addr = xnet.IPAddress(ips[0])	}	dst := xnet.Destination{		Address: addr,		Port:    xnet.Port(portNum),		Network: xnet.Network_UDP,	}	return &netEndpoint{		dst: dst,	}, nil}// BatchSize implements conn.Bindfunc (bind *netBind) BatchSize() int {	return 1}// Open implements conn.Bindfunc (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {	bind.readQueue = make(chan *netReadInfo)	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {		defer func() {			if r := recover(); r != nil {				n = 0				err = errors.New("channel closed")			}		}()		r := &netReadInfo{			buff: bufs[0],		}		r.waiter.Add(1)		bind.readQueue <- r		r.waiter.Wait() // wait read goroutine done, or we will miss the result		sizes[0], eps[0] = r.bytes, r.endpoint		return 1, r.err	}	workers := bind.workers	if workers <= 0 {		workers = 1	}	arr := make([]conn.ReceiveFunc, workers)	for i := 0; i < workers; i++ {		arr[i] = fun	}	return arr, uint16(uport), nil}// Close implements conn.Bindfunc (bind *netBind) Close() error {	if bind.readQueue != nil {		close(bind.readQueue)	}	return nil}type netBindClient struct {	netBind	ctx      context.Context	dialer   internet.Dialer	reserved []byte}func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {	c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)	if err != nil {		return err	}	endpoint.conn = c	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {		for {			v, ok := <-readQueue			if !ok {				return			}			i, err := c.Read(v.buff)			if i > 3 {				v.buff[1] = 0				v.buff[2] = 0				v.buff[3] = 0			}			v.bytes = i			v.endpoint = endpoint			v.err = err			v.waiter.Done()			if err != nil && errors.Is(err, io.EOF) {				endpoint.conn = nil				return			}		}	}(bind.readQueue, endpoint)	return nil}func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {	var err error	nend, ok := endpoint.(*netEndpoint)	if !ok {		return conn.ErrWrongEndpointType	}	if nend.conn == nil {		err = bind.connectTo(nend)		if err != nil {			return err		}	}	for _, buff := range buff {		if len(buff) > 3 && len(bind.reserved) == 3 {			copy(buff[1:], bind.reserved)		}		if _, err = nend.conn.Write(buff); err != nil {			return err		}	}	return nil}type netBindServer struct {	netBind}func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {	var err error	nend, ok := endpoint.(*netEndpoint)	if !ok {		return conn.ErrWrongEndpointType	}	if nend.conn == nil {		return errors.New("connection not open yet")	}	for _, buff := range buff {		if _, err = nend.conn.Write(buff); err != nil {			return err		}	}	return err}type netEndpoint struct {	dst  xnet.Destination	conn net.Conn}func (netEndpoint) ClearSrc() {}func (e netEndpoint) DstIP() netip.Addr {	return netip.Addr{}}func (e netEndpoint) SrcIP() netip.Addr {	return netip.Addr{}}func (e netEndpoint) DstToBytes() []byte {	var dat []byte	if e.dst.Address.Family().IsIPv4() {		dat = e.dst.Address.IP().To4()[:]	} else {		dat = e.dst.Address.IP().To16()[:]	}	dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))	return dat}func (e netEndpoint) DstToString() string {	return e.dst.NetAddr()}func (e netEndpoint) SrcToString() string {	return ""}func toNetIpAddr(addr xnet.Address) netip.Addr {	if addr.Family().IsIPv4() {		ip := addr.IP()		return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})	} else {		ip := addr.IP()		arr := [16]byte{}		for i := 0; i < 16; i++ {			arr[i] = ip[i]		}		return netip.AddrFrom16(arr)	}}
 |