| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 | 
							- package route
 
- import (
 
- 	"context"
 
- 	"errors"
 
- 	"io"
 
- 	"net"
 
- 	"net/netip"
 
- 	"os"
 
- 	"sync"
 
- 	"sync/atomic"
 
- 	"time"
 
- 	"github.com/sagernet/sing-box/adapter"
 
- 	"github.com/sagernet/sing-box/common/dialer"
 
- 	C "github.com/sagernet/sing-box/constant"
 
- 	"github.com/sagernet/sing/common"
 
- 	"github.com/sagernet/sing/common/buf"
 
- 	"github.com/sagernet/sing/common/bufio"
 
- 	"github.com/sagernet/sing/common/canceler"
 
- 	E "github.com/sagernet/sing/common/exceptions"
 
- 	"github.com/sagernet/sing/common/logger"
 
- 	M "github.com/sagernet/sing/common/metadata"
 
- 	N "github.com/sagernet/sing/common/network"
 
- 	"github.com/sagernet/sing/common/x/list"
 
- )
 
- var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
 
- type ConnectionManager struct {
 
- 	logger      logger.ContextLogger
 
- 	access      sync.Mutex
 
- 	connections list.List[io.Closer]
 
- }
 
- func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
 
- 	return &ConnectionManager{
 
- 		logger: logger,
 
- 	}
 
- }
 
- func (m *ConnectionManager) Start(stage adapter.StartStage) error {
 
- 	return nil
 
- }
 
- func (m *ConnectionManager) Close() error {
 
- 	m.access.Lock()
 
- 	defer m.access.Unlock()
 
- 	for element := m.connections.Front(); element != nil; element = element.Next() {
 
- 		common.Close(element.Value)
 
- 	}
 
- 	m.connections.Init()
 
- 	return nil
 
- }
 
- func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
 
- 	ctx = adapter.WithContext(ctx, &metadata)
 
- 	var (
 
- 		remoteConn net.Conn
 
- 		err        error
 
- 	)
 
- 	if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
 
- 		remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
 
- 	} else {
 
- 		remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
 
- 	}
 
- 	if err != nil {
 
- 		err = E.Cause(err, "open outbound connection")
 
- 		N.CloseOnHandshakeFailure(conn, onClose, err)
 
- 		m.logger.ErrorContext(ctx, err)
 
- 		return
 
- 	}
 
- 	err = N.ReportConnHandshakeSuccess(conn, remoteConn)
 
- 	if err != nil {
 
- 		err = E.Cause(err, "report handshake success")
 
- 		remoteConn.Close()
 
- 		N.CloseOnHandshakeFailure(conn, onClose, err)
 
- 		m.logger.ErrorContext(ctx, err)
 
- 		return
 
- 	}
 
- 	m.access.Lock()
 
- 	element := m.connections.PushBack(conn)
 
- 	m.access.Unlock()
 
- 	onClose = N.AppendClose(onClose, func(it error) {
 
- 		m.access.Lock()
 
- 		defer m.access.Unlock()
 
- 		m.connections.Remove(element)
 
- 	})
 
- 	var done atomic.Bool
 
- 	go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
 
- 	go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
 
- }
 
- func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
 
- 	ctx = adapter.WithContext(ctx, &metadata)
 
- 	var (
 
- 		remotePacketConn   net.PacketConn
 
- 		remoteConn         net.Conn
 
- 		destinationAddress netip.Addr
 
- 		err                error
 
- 	)
 
- 	if metadata.UDPConnect {
 
- 		parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
 
- 		if len(metadata.DestinationAddresses) > 0 {
 
- 			if isParallelDialer {
 
- 				remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
 
- 			} else {
 
- 				remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
 
- 			}
 
- 		} else if metadata.Destination.IsIP() {
 
- 			if isParallelDialer {
 
- 				remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
 
- 			} else {
 
- 				remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
 
- 			}
 
- 		} else {
 
- 			remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
 
- 		}
 
- 		if err != nil {
 
- 			N.CloseOnHandshakeFailure(conn, onClose, err)
 
- 			m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
 
- 			return
 
- 		}
 
- 		remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
 
- 		connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
 
- 		if connRemoteAddr != metadata.Destination.Addr {
 
- 			destinationAddress = connRemoteAddr
 
- 		}
 
- 	} else {
 
- 		if len(metadata.DestinationAddresses) > 0 {
 
- 			remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
 
- 		} else {
 
- 			remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
 
- 		}
 
- 		if err != nil {
 
- 			N.CloseOnHandshakeFailure(conn, onClose, err)
 
- 			m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
 
- 			return
 
- 		}
 
- 	}
 
- 	err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
 
- 	if err != nil {
 
- 		conn.Close()
 
- 		remotePacketConn.Close()
 
- 		m.logger.ErrorContext(ctx, "report handshake success: ", err)
 
- 		return
 
- 	}
 
- 	if destinationAddress.IsValid() {
 
- 		var originDestination M.Socksaddr
 
- 		if metadata.RouteOriginalDestination.IsValid() {
 
- 			originDestination = metadata.RouteOriginalDestination
 
- 		} else {
 
- 			originDestination = metadata.Destination
 
- 		}
 
- 		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
 
- 			natConn.UpdateDestination(destinationAddress)
 
- 		} else if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
 
- 			if metadata.UDPDisableDomainUnmapping {
 
- 				remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
 
- 			} else {
 
- 				remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
 
- 			}
 
- 		}
 
- 	} else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
 
- 		remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
 
- 	}
 
- 	var udpTimeout time.Duration
 
- 	if metadata.UDPTimeout > 0 {
 
- 		udpTimeout = metadata.UDPTimeout
 
- 	} else {
 
- 		protocol := metadata.Protocol
 
- 		if protocol == "" {
 
- 			protocol = C.PortProtocols[metadata.Destination.Port]
 
- 		}
 
- 		if protocol != "" {
 
- 			udpTimeout = C.ProtocolTimeouts[protocol]
 
- 		}
 
- 	}
 
- 	if udpTimeout > 0 {
 
- 		ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
 
- 	}
 
- 	destination := bufio.NewPacketConn(remotePacketConn)
 
- 	m.access.Lock()
 
- 	element := m.connections.PushBack(conn)
 
- 	m.access.Unlock()
 
- 	onClose = N.AppendClose(onClose, func(it error) {
 
- 		m.access.Lock()
 
- 		defer m.access.Unlock()
 
- 		m.connections.Remove(element)
 
- 	})
 
- 	var done atomic.Bool
 
- 	go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
 
- 	go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
 
- }
 
- func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
 
- 	var (
 
- 		sourceReader      io.Reader = source
 
- 		destinationWriter io.Writer = destination
 
- 	)
 
- 	var readCounters, writeCounters []N.CountFunc
 
- 	for {
 
- 		sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
 
- 		destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
 
- 		if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
 
- 			cachedBuffer := cachedSrc.ReadCached()
 
- 			if cachedBuffer != nil {
 
- 				dataLen := cachedBuffer.Len()
 
- 				_, err := destination.Write(cachedBuffer.Bytes())
 
- 				cachedBuffer.Release()
 
- 				if err != nil {
 
- 					if done.Swap(true) {
 
- 						onClose(err)
 
- 					}
 
- 					common.Close(source, destination)
 
- 					if !direction {
 
- 						m.logger.ErrorContext(ctx, "connection upload payload: ", err)
 
- 					} else {
 
- 						m.logger.ErrorContext(ctx, "connection download payload: ", err)
 
- 					}
 
- 					return
 
- 				}
 
- 				for _, counter := range readCounters {
 
- 					counter(int64(dataLen))
 
- 				}
 
- 				for _, counter := range writeCounters {
 
- 					counter(int64(dataLen))
 
- 				}
 
- 			}
 
- 			continue
 
- 		}
 
- 		break
 
- 	}
 
- 	if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
 
- 		err := m.connectionCopyEarly(source, destination)
 
- 		if err != nil {
 
- 			if done.Swap(true) {
 
- 				onClose(err)
 
- 			}
 
- 			common.Close(source, destination)
 
- 			if !direction {
 
- 				m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
 
- 			} else {
 
- 				m.logger.ErrorContext(ctx, "connection download handshake: ", err)
 
- 			}
 
- 			return
 
- 		}
 
- 	}
 
- 	_, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
 
- 	if err != nil {
 
- 		common.Close(source, destination)
 
- 	} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
 
- 		err = duplexDst.CloseWrite()
 
- 		if err != nil {
 
- 			common.Close(source, destination)
 
- 		}
 
- 	} else {
 
- 		destination.Close()
 
- 	}
 
- 	if done.Swap(true) {
 
- 		onClose(err)
 
- 		common.Close(source, destination)
 
- 	}
 
- 	if !direction {
 
- 		if err == nil {
 
- 			m.logger.DebugContext(ctx, "connection upload finished")
 
- 		} else if !E.IsClosedOrCanceled(err) {
 
- 			m.logger.ErrorContext(ctx, "connection upload closed: ", err)
 
- 		} else {
 
- 			m.logger.TraceContext(ctx, "connection upload closed")
 
- 		}
 
- 	} else {
 
- 		if err == nil {
 
- 			m.logger.DebugContext(ctx, "connection download finished")
 
- 		} else if !E.IsClosedOrCanceled(err) {
 
- 			m.logger.ErrorContext(ctx, "connection download closed: ", err)
 
- 		} else {
 
- 			m.logger.TraceContext(ctx, "connection download closed")
 
- 		}
 
- 	}
 
- }
 
- func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
 
- 	payload := buf.NewPacket()
 
- 	defer payload.Release()
 
- 	err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
 
- 	if err != nil {
 
- 		if err == os.ErrInvalid {
 
- 			return common.Error(destination.Write(nil))
 
- 		}
 
- 		return err
 
- 	}
 
- 	_, err = payload.ReadOnceFrom(source)
 
- 	if err != nil && !(E.IsTimeout(err) || errors.Is(err, io.EOF)) {
 
- 		return E.Cause(err, "read payload")
 
- 	}
 
- 	_ = source.SetReadDeadline(time.Time{})
 
- 	_, err = destination.Write(payload.Bytes())
 
- 	if err != nil {
 
- 		return E.Cause(err, "write payload")
 
- 	}
 
- 	return nil
 
- }
 
- func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
 
- 	_, err := bufio.CopyPacket(destination, source)
 
- 	if !direction {
 
- 		if err == nil {
 
- 			m.logger.DebugContext(ctx, "packet upload finished")
 
- 		} else if E.IsClosedOrCanceled(err) {
 
- 			m.logger.TraceContext(ctx, "packet upload closed")
 
- 		} else {
 
- 			m.logger.DebugContext(ctx, "packet upload closed: ", err)
 
- 		}
 
- 	} else {
 
- 		if err == nil {
 
- 			m.logger.DebugContext(ctx, "packet download finished")
 
- 		} else if E.IsClosedOrCanceled(err) {
 
- 			m.logger.TraceContext(ctx, "packet download closed")
 
- 		} else {
 
- 			m.logger.DebugContext(ctx, "packet download closed: ", err)
 
- 		}
 
- 	}
 
- 	if !done.Swap(true) {
 
- 		onClose(err)
 
- 	}
 
- 	common.Close(source, destination)
 
- }
 
 
  |