| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361 |
- package wireguard
import (
	"context"
	"errors"
	stdnet "net"
	"net/netip"
	"strconv"
	"sync"

	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/features/dns"
	"github.com/xtls/xray-core/transport/internet"
	"golang.zx2c4.com/wireguard/conn"
)
// netReadInfo is one receive request. A ReceiveFunc returned by Open fills
// in buff, enqueues the request on netBind.readQueue and blocks on waiter.
// The goroutine that takes the request from the queue stores the result
// fields and then calls waiter.Done.
type netReadInfo struct {
	// status: released (Done) once the result fields below have been set
	waiter sync.WaitGroup
	// param: caller-supplied buffer that receives the packet bytes
	buff []byte
	// result: bytes written into buff, the endpoint they came from, and any
	// read error
	bytes    int
	endpoint conn.Endpoint
	err      error
}
// netBind holds the state and conn.Bind methods shared by the client and
// server binds (netBindClient and netBindServer embed it).
type netBind struct {
	// dns and dnsOption are used by ParseEndpoint to resolve domain hosts.
	dns       dns.Client
	dnsOption dns.IPOption
	// workers is the number of ReceiveFuncs Open returns (at least 1).
	workers int
	// readQueue carries receive requests from the ReceiveFuncs to whichever
	// goroutine serves them. Open creates it and Close closes it.
	readQueue chan *netReadInfo
}
- // SetMark implements conn.Bind
- func (bind *netBind) SetMark(mark uint32) error {
- return nil
- }
- // ParseEndpoint implements conn.Bind
- func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- ipStr, port, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- portNum, err := strconv.Atoi(port)
- if err != nil {
- return nil, err
- }
- addr := net.ParseAddress(ipStr)
- if addr.Family() == net.AddressFamilyDomain {
- ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
- if err != nil {
- return nil, err
- } else if len(ips) == 0 {
- return nil, dns.ErrEmptyResponse
- }
- addr = net.IPAddress(ips[0])
- }
- dst := net.Destination{
- Address: addr,
- Port: net.Port(portNum),
- Network: net.Network_UDP,
- }
- return &netEndpoint{
- dst: dst,
- }, nil
- }
- // BatchSize implements conn.Bind
- func (bind *netBind) BatchSize() int {
- return 1
- }
- // Open implements conn.Bind
- func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
- bind.readQueue = make(chan *netReadInfo)
- fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
- defer func() {
- if r := recover(); r != nil {
- n = 0
- err = errors.New("channel closed")
- }
- }()
- r := &netReadInfo{
- buff: bufs[0],
- }
- r.waiter.Add(1)
- bind.readQueue <- r
- r.waiter.Wait() // wait read goroutine done, or we will miss the result
- sizes[0], eps[0] = r.bytes, r.endpoint
- return 1, r.err
- }
- workers := bind.workers
- if workers <= 0 {
- workers = 1
- }
- arr := make([]conn.ReceiveFunc, workers)
- for i := 0; i < workers; i++ {
- arr[i] = fun
- }
- return arr, uint16(uport), nil
- }
- // Close implements conn.Bind
- func (bind *netBind) Close() error {
- if bind.readQueue != nil {
- close(bind.readQueue)
- }
- return nil
- }
// netBindClient is the client-side bind. It dials one connection per peer
// endpoint through dialer. Everything read from those connections is
// funnelled to the shared read queue by unifiedReader.
type netBindClient struct {
	netBind
	// ctx is passed to dialer.Dial when connecting to an endpoint.
	ctx    context.Context
	dialer internet.Dialer
	// reserved, when exactly 3 bytes long, is copied into bytes 1-3 of each
	// outgoing packet longer than 3 bytes.
	reserved []byte

	// Track all peer connections for unified reading.
	// connMutex guards conns, endpoint.conn updates, and the lazy creation of
	// dataChan and closeChan on the first connection.
	connMutex sync.RWMutex
	conns     map[*netEndpoint]net.Conn
	// dataChan carries read results from per-connection readers to
	// unifiedReader.
	dataChan chan *receivedData
	// closeChan is closed (at most once, via closeOnce) to stop the readers.
	closeChan chan struct{}
	closeOnce sync.Once
}
// dataChannelBufferSize is the capacity of netBindClient.dataChan. It lets
// per-connection readers queue up to this many received packets while the
// dispatcher waits for read requests to pair them with.
const dataChannelBufferSize = 100
// receivedData is one read result forwarded from a per-connection reader
// goroutine to unifiedReader.
type receivedData struct {
	data     []byte       // buffer holding the bytes read; only data[:n] is used
	n        int          // byte count returned by Read
	endpoint *netEndpoint // endpoint whose connection produced this result
	err      error        // error returned by Read, if any
}
- func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
- c, err := bind.dialer.Dial(bind.ctx, endpoint.dst)
- if err != nil {
- return err
- }
- endpoint.conn = c
- // Initialize channels on first connection
- bind.connMutex.Lock()
- if bind.conns == nil {
- bind.conns = make(map[*netEndpoint]net.Conn)
- bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
- bind.closeChan = make(chan struct{})
-
- // Start unified reader dispatcher
- go bind.unifiedReader()
- }
- bind.conns[endpoint] = c
- bind.connMutex.Unlock()
-
- // Start a reader goroutine for this specific connection
- go func(conn net.Conn, endpoint *netEndpoint) {
- const maxPacketSize = 1500
- for {
- select {
- case <-bind.closeChan:
- return
- default:
- }
-
- buf := make([]byte, maxPacketSize)
- n, err := conn.Read(buf)
-
- // Send only the valid data portion to dispatcher
- dataToSend := buf
- if n > 0 && n < len(buf) {
- dataToSend = buf[:n]
- }
-
- // Send received data to dispatcher
- select {
- case bind.dataChan <- &receivedData{
- data: dataToSend,
- n: n,
- endpoint: endpoint,
- err: err,
- }:
- case <-bind.closeChan:
- return
- }
-
- if err != nil {
- bind.connMutex.Lock()
- delete(bind.conns, endpoint)
- endpoint.conn = nil
- bind.connMutex.Unlock()
- return
- }
- }
- }(c, endpoint)
- return nil
- }
- // unifiedReader dispatches received data to waiting read requests
- func (bind *netBindClient) unifiedReader() {
- for {
- select {
- case data := <-bind.dataChan:
- // Bounds check to prevent panic
- if data.n > len(data.data) {
- data.n = len(data.data)
- }
-
- // Wait for a read request with timeout to prevent blocking forever
- select {
- case v := <-bind.readQueue:
- // Copy data to request buffer
- n := copy(v.buff, data.data[:data.n])
-
- // Clear reserved bytes if needed
- if n > 3 {
- v.buff[1] = 0
- v.buff[2] = 0
- v.buff[3] = 0
- }
-
- v.bytes = n
- v.endpoint = data.endpoint
- v.err = data.err
- v.waiter.Done()
- case <-bind.closeChan:
- return
- }
- case <-bind.closeChan:
- return
- }
- }
- }
- // Close implements conn.Bind.Close for netBindClient
- func (bind *netBindClient) Close() error {
- // Use sync.Once to prevent double-close panic
- bind.closeOnce.Do(func() {
- bind.connMutex.Lock()
- if bind.closeChan != nil {
- close(bind.closeChan)
- }
- bind.connMutex.Unlock()
- })
-
- // Call parent Close
- return bind.netBind.Close()
- }
- func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
- var err error
- nend, ok := endpoint.(*netEndpoint)
- if !ok {
- return conn.ErrWrongEndpointType
- }
- if nend.conn == nil {
- err = bind.connectTo(nend)
- if err != nil {
- return err
- }
- }
- for _, buff := range buff {
- if len(buff) > 3 && len(bind.reserved) == 3 {
- copy(buff[1:], bind.reserved)
- }
- if _, err = nend.conn.Write(buff); err != nil {
- return err
- }
- }
- return nil
- }
// netBindServer is the server-side bind. Unlike netBindClient it never dials:
// Send only writes to an endpoint whose conn has already been set.
type netBindServer struct {
	netBind
}
- func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
- var err error
- nend, ok := endpoint.(*netEndpoint)
- if !ok {
- return conn.ErrWrongEndpointType
- }
- if nend.conn == nil {
- return errors.New("connection not open yet")
- }
- for _, buff := range buff {
- if _, err = nend.conn.Write(buff); err != nil {
- return err
- }
- }
- return err
- }
// netEndpoint is the conn.Endpoint used by these binds. It pairs a
// destination (ParseEndpoint produces UDP ones) with the connection used to
// reach it.
type netEndpoint struct {
	dst net.Destination
	// conn is nil until a connection to dst has been established. On the
	// client side it is reset to nil when that connection fails.
	conn net.Conn
}
- func (netEndpoint) ClearSrc() {}
- func (e netEndpoint) DstIP() netip.Addr {
- return netip.Addr{}
- }
- func (e netEndpoint) SrcIP() netip.Addr {
- return netip.Addr{}
- }
- func (e netEndpoint) DstToBytes() []byte {
- var dat []byte
- if e.dst.Address.Family().IsIPv4() {
- dat = e.dst.Address.IP().To4()[:]
- } else {
- dat = e.dst.Address.IP().To16()[:]
- }
- dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
- return dat
- }
// DstToString implements conn.Endpoint, returning the destination formatted
// by net.Destination.NetAddr.
func (e netEndpoint) DstToString() string {
	return e.dst.NetAddr()
}
// SrcToString implements conn.Endpoint. No source address is tracked, so the
// result is always the empty string.
func (e netEndpoint) SrcToString() string {
	return ""
}
- func toNetIpAddr(addr net.Address) netip.Addr {
- if addr.Family().IsIPv4() {
- ip := addr.IP()
- return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
- } else {
- ip := addr.IP()
- arr := [16]byte{}
- for i := 0; i < 16; i++ {
- arr[i] = ip[i]
- }
- return netip.AddrFrom16(arr)
- }
- }
|