123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- package route
- import (
- "context"
- "errors"
- "io"
- "net"
- "net/netip"
- "os"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/sagernet/sing-box/adapter"
- "github.com/sagernet/sing-box/common/dialer"
- "github.com/sagernet/sing-box/common/tlsfragment"
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing/common"
- "github.com/sagernet/sing/common/buf"
- "github.com/sagernet/sing/common/bufio"
- "github.com/sagernet/sing/common/canceler"
- E "github.com/sagernet/sing/common/exceptions"
- "github.com/sagernet/sing/common/logger"
- M "github.com/sagernet/sing/common/metadata"
- N "github.com/sagernet/sing/common/network"
- "github.com/sagernet/sing/common/x/list"
- )
- var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
- type ConnectionManager struct {
- logger logger.ContextLogger
- access sync.Mutex
- connections list.List[io.Closer]
- }
- func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
- return &ConnectionManager{
- logger: logger,
- }
- }
- func (m *ConnectionManager) Start(stage adapter.StartStage) error {
- return nil
- }
- func (m *ConnectionManager) Close() error {
- m.access.Lock()
- defer m.access.Unlock()
- for element := m.connections.Front(); element != nil; element = element.Next() {
- common.Close(element.Value)
- }
- m.connections.Init()
- return nil
- }
- func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
- ctx = adapter.WithContext(ctx, &metadata)
- var (
- remoteConn net.Conn
- err error
- )
- if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
- remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
- } else {
- remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
- }
- if err != nil {
- var remoteString string
- if len(metadata.DestinationAddresses) > 0 {
- remoteString = "[" + strings.Join(common.Map(metadata.DestinationAddresses, netip.Addr.String), ",") + "]"
- } else {
- remoteString = metadata.Destination.String()
- }
- var dialerString string
- if outbound, isOutbound := this.(adapter.Outbound); isOutbound {
- dialerString = " using outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
- }
- err = E.Cause(err, "open connection to ", remoteString, dialerString)
- N.CloseOnHandshakeFailure(conn, onClose, err)
- m.logger.ErrorContext(ctx, err)
- return
- }
- err = N.ReportConnHandshakeSuccess(conn, remoteConn)
- if err != nil {
- err = E.Cause(err, "report handshake success")
- remoteConn.Close()
- N.CloseOnHandshakeFailure(conn, onClose, err)
- m.logger.ErrorContext(ctx, err)
- return
- }
- if metadata.TLSFragment || metadata.TLSRecordFragment {
- remoteConn = tf.NewConn(remoteConn, ctx, metadata.TLSFragment, metadata.TLSRecordFragment, metadata.TLSFragmentFallbackDelay)
- }
- m.access.Lock()
- element := m.connections.PushBack(conn)
- m.access.Unlock()
- onClose = N.AppendClose(onClose, func(it error) {
- m.access.Lock()
- defer m.access.Unlock()
- m.connections.Remove(element)
- })
- var done atomic.Bool
- m.preConnectionCopy(ctx, conn, remoteConn, false, &done, onClose)
- m.preConnectionCopy(ctx, remoteConn, conn, true, &done, onClose)
- go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
- go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
- }
- func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
- ctx = adapter.WithContext(ctx, &metadata)
- var (
- remotePacketConn net.PacketConn
- remoteConn net.Conn
- destinationAddress netip.Addr
- err error
- )
- if metadata.UDPConnect {
- parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
- if len(metadata.DestinationAddresses) > 0 {
- if isParallelDialer {
- remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
- } else {
- remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
- }
- } else if metadata.Destination.IsIP() {
- if isParallelDialer {
- remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
- } else {
- remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
- }
- } else {
- remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
- }
- if err != nil {
- var remoteString string
- if len(metadata.DestinationAddresses) > 0 {
- remoteString = "[" + strings.Join(common.Map(metadata.DestinationAddresses, netip.Addr.String), ",") + "]"
- } else {
- remoteString = metadata.Destination.String()
- }
- var dialerString string
- if outbound, isOutbound := this.(adapter.Outbound); isOutbound {
- dialerString = " using outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
- }
- err = E.Cause(err, "open packet connection to ", remoteString, dialerString)
- N.CloseOnHandshakeFailure(conn, onClose, err)
- m.logger.ErrorContext(ctx, err)
- return
- }
- remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
- connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
- if connRemoteAddr != metadata.Destination.Addr {
- destinationAddress = connRemoteAddr
- }
- } else {
- if len(metadata.DestinationAddresses) > 0 {
- remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
- } else {
- remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
- }
- if err != nil {
- var dialerString string
- if outbound, isOutbound := this.(adapter.Outbound); isOutbound {
- dialerString = " using outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
- }
- err = E.Cause(err, "listen packet connection using ", dialerString)
- N.CloseOnHandshakeFailure(conn, onClose, err)
- m.logger.ErrorContext(ctx, err)
- return
- }
- }
- err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
- if err != nil {
- conn.Close()
- remotePacketConn.Close()
- m.logger.ErrorContext(ctx, "report handshake success: ", err)
- return
- }
- if destinationAddress.IsValid() {
- var originDestination M.Socksaddr
- if metadata.RouteOriginalDestination.IsValid() {
- originDestination = metadata.RouteOriginalDestination
- } else {
- originDestination = metadata.Destination
- }
- if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
- natConn.UpdateDestination(destinationAddress)
- } else if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
- if metadata.UDPDisableDomainUnmapping {
- remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
- } else {
- remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
- }
- }
- } else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
- remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
- }
- var udpTimeout time.Duration
- if metadata.UDPTimeout > 0 {
- udpTimeout = metadata.UDPTimeout
- } else {
- protocol := metadata.Protocol
- if protocol == "" {
- protocol = C.PortProtocols[metadata.Destination.Port]
- }
- if protocol != "" {
- udpTimeout = C.ProtocolTimeouts[protocol]
- }
- }
- if udpTimeout > 0 {
- ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
- }
- destination := bufio.NewPacketConn(remotePacketConn)
- m.access.Lock()
- element := m.connections.PushBack(conn)
- m.access.Unlock()
- onClose = N.AppendClose(onClose, func(it error) {
- m.access.Lock()
- defer m.access.Unlock()
- m.connections.Remove(element)
- })
- var done atomic.Bool
- go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
- go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
- }
- func (m *ConnectionManager) preConnectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
- readHandshake := N.NeedHandshakeForRead(source)
- writeHandshake := N.NeedHandshakeForWrite(destination)
- if readHandshake || writeHandshake {
- var err error
- for {
- err = m.connectionCopyEarlyWrite(source, destination, readHandshake, writeHandshake)
- if err == nil && N.NeedHandshakeForRead(source) {
- continue
- } else if E.IsMulti(err, os.ErrInvalid, context.DeadlineExceeded, io.EOF) {
- err = nil
- }
- break
- }
- if err != nil {
- if done.Swap(true) {
- onClose(err)
- }
- common.Close(source, destination)
- if !direction {
- m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
- } else {
- m.logger.ErrorContext(ctx, "connection download handshake: ", err)
- }
- return
- }
- }
- }
- func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
- var (
- sourceReader io.Reader = source
- destinationWriter io.Writer = destination
- )
- var readCounters, writeCounters []N.CountFunc
- for {
- sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
- destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
- if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
- cachedBuffer := cachedSrc.ReadCached()
- if cachedBuffer != nil {
- dataLen := cachedBuffer.Len()
- _, err := destination.Write(cachedBuffer.Bytes())
- cachedBuffer.Release()
- if err != nil {
- if done.Swap(true) {
- onClose(err)
- }
- common.Close(source, destination)
- if !direction {
- m.logger.ErrorContext(ctx, "connection upload payload: ", err)
- } else {
- m.logger.ErrorContext(ctx, "connection download payload: ", err)
- }
- return
- }
- for _, counter := range readCounters {
- counter(int64(dataLen))
- }
- for _, counter := range writeCounters {
- counter(int64(dataLen))
- }
- }
- continue
- }
- break
- }
- _, err := bufio.CopyWithCounters(destinationWriter, sourceReader, source, readCounters, writeCounters, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize)
- if err != nil {
- common.Close(source, destination)
- } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
- err = duplexDst.CloseWrite()
- if err != nil {
- common.Close(source, destination)
- }
- } else {
- destination.Close()
- }
- if done.Swap(true) {
- onClose(err)
- common.Close(source, destination)
- }
- if !direction {
- if err == nil {
- m.logger.DebugContext(ctx, "connection upload finished")
- } else if !E.IsClosedOrCanceled(err) {
- m.logger.ErrorContext(ctx, "connection upload closed: ", err)
- } else {
- m.logger.TraceContext(ctx, "connection upload closed")
- }
- } else {
- if err == nil {
- m.logger.DebugContext(ctx, "connection download finished")
- } else if !E.IsClosedOrCanceled(err) {
- m.logger.ErrorContext(ctx, "connection download closed: ", err)
- } else {
- m.logger.TraceContext(ctx, "connection download closed")
- }
- }
- }
- func (m *ConnectionManager) connectionCopyEarlyWrite(source net.Conn, destination io.Writer, readHandshake bool, writeHandshake bool) error {
- payload := buf.NewPacket()
- defer payload.Release()
- err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
- if err != nil {
- if err == os.ErrInvalid {
- if writeHandshake {
- return common.Error(destination.Write(nil))
- }
- }
- return err
- }
- var (
- isTimeout bool
- isEOF bool
- )
- _, err = payload.ReadOnceFrom(source)
- if err != nil {
- if E.IsTimeout(err) {
- isTimeout = true
- } else if errors.Is(err, io.EOF) {
- isEOF = true
- } else {
- return E.Cause(err, "read payload")
- }
- }
- _ = source.SetReadDeadline(time.Time{})
- if !payload.IsEmpty() || writeHandshake {
- _, err = destination.Write(payload.Bytes())
- if err != nil {
- return E.Cause(err, "write payload")
- }
- }
- if isTimeout {
- return context.DeadlineExceeded
- } else if isEOF {
- return io.EOF
- }
- return nil
- }
- func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
- _, err := bufio.CopyPacket(destination, source)
- if !direction {
- if err == nil {
- m.logger.DebugContext(ctx, "packet upload finished")
- } else if E.IsClosedOrCanceled(err) {
- m.logger.TraceContext(ctx, "packet upload closed")
- } else {
- m.logger.DebugContext(ctx, "packet upload closed: ", err)
- }
- } else {
- if err == nil {
- m.logger.DebugContext(ctx, "packet download finished")
- } else if E.IsClosedOrCanceled(err) {
- m.logger.TraceContext(ctx, "packet download closed")
- } else {
- m.logger.DebugContext(ctx, "packet download closed: ", err)
- }
- }
- if !done.Swap(true) {
- onClose(err)
- }
- common.Close(source, destination)
- }
|