Browse Source

Improve timeouts

世界 10 months ago
parent
commit
22bda86bbf
4 changed files with 104 additions and 345 deletions
  1. 7 1
      protocol/wireguard/endpoint.go
  2. 97 173
      route/conn.go
  3. 0 128
      route/conn_monitor.go
  4. 0 43
      route/conn_monitor_test.go

+ 7 - 1
protocol/wireguard/endpoint.go

@@ -56,12 +56,18 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
 	if err != nil {
 		return nil, err
 	}
+	var udpTimeout time.Duration
+	if options.UDPTimeout != 0 {
+		udpTimeout = time.Duration(options.UDPTimeout)
+	} else {
+		udpTimeout = C.UDPTimeout
+	}
 	wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{
 		Context:    ctx,
 		Logger:     logger,
 		System:     options.System,
 		Handler:    ep,
-		UDPTimeout: time.Duration(options.UDPTimeout),
+		UDPTimeout: udpTimeout,
 		Dialer:     outboundDialer,
 		CreateDialer: func(interfaceName string) N.Dialer {
 			return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{

+ 97 - 173
route/conn.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net"
 	"net/netip"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -18,31 +19,35 @@ import (
 	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/x/list"
 )
 
 var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
 
 type ConnectionManager struct {
-	logger  logger.ContextLogger
-	monitor *ConnectionMonitor
+	logger      logger.ContextLogger
+	access      sync.Mutex
+	connections list.List[io.Closer]
 }
 
 func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
 	return &ConnectionManager{
-		logger:  logger,
-		monitor: NewConnectionMonitor(),
+		logger: logger,
 	}
 }
 
 func (m *ConnectionManager) Start(stage adapter.StartStage) error {
-	if stage != adapter.StartStateInitialize {
-		return nil
-	}
-	return m.monitor.Start()
+	return nil
 }
 
 func (m *ConnectionManager) Close() error {
-	return m.monitor.Close()
+	m.access.Lock()
+	defer m.access.Unlock()
+	for element := m.connections.Front(); element != nil; element = element.Next() {
+		common.Close(element.Value)
+	}
+	m.connections.Init()
+	return nil
 }
 
 func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
@@ -57,95 +62,32 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
 		remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
 	}
 	if err != nil {
+		err = E.Cause(err, "open outbound connection")
 		N.CloseOnHandshakeFailure(conn, onClose, err)
-		m.logger.ErrorContext(ctx, "open outbound connection: ", err)
+		m.logger.ErrorContext(ctx, err)
 		return
 	}
 	err = N.ReportConnHandshakeSuccess(conn, remoteConn)
 	if err != nil {
+		err = E.Cause(err, "report handshake success")
 		remoteConn.Close()
 		N.CloseOnHandshakeFailure(conn, onClose, err)
-		m.logger.ErrorContext(ctx, "report handshake success: ", err)
+		m.logger.ErrorContext(ctx, err)
 		return
 	}
+	m.access.Lock()
+	element := m.connections.PushBack(conn)
+	m.access.Unlock()
+	onClose = N.AppendClose(onClose, func(it error) {
+		m.access.Lock()
+		defer m.access.Unlock()
+		m.connections.Remove(element)
+	})
 	var done atomic.Bool
-	if ctx.Done() != nil {
-		onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
-	}
 	go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
 	go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
 }
 
-func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
-	originSource := source
-	originDestination := destination
-	var readCounters, writeCounters []N.CountFunc
-	for {
-		source, readCounters = N.UnwrapCountReader(source, readCounters)
-		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
-		if cachedSrc, isCached := source.(N.CachedReader); isCached {
-			cachedBuffer := cachedSrc.ReadCached()
-			if cachedBuffer != nil {
-				dataLen := cachedBuffer.Len()
-				_, err := destination.Write(cachedBuffer.Bytes())
-				cachedBuffer.Release()
-				if err != nil {
-					m.logger.ErrorContext(ctx, "connection upload payload: ", err)
-					if done.Swap(true) {
-						if onClose != nil {
-							onClose(err)
-						}
-					}
-					common.Close(originSource, originDestination)
-					return
-				}
-				for _, counter := range readCounters {
-					counter(int64(dataLen))
-				}
-				for _, counter := range writeCounters {
-					counter(int64(dataLen))
-				}
-			}
-			continue
-		}
-		break
-	}
-	_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
-	if err != nil {
-		common.Close(originSource, originDestination)
-	} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
-		err = duplexDst.CloseWrite()
-		if err != nil {
-			common.Close(originSource, originDestination)
-		}
-	} else {
-		common.Close(originDestination)
-	}
-	if done.Swap(true) {
-		if onClose != nil {
-			onClose(err)
-		}
-		common.Close(originSource, originDestination)
-	}
-	if !direction {
-		if err == nil {
-			m.logger.DebugContext(ctx, "connection upload finished")
-		} else if !E.IsClosedOrCanceled(err) {
-			m.logger.ErrorContext(ctx, "connection upload closed: ", err)
-		} else {
-			m.logger.TraceContext(ctx, "connection upload closed")
-		}
-	} else {
-		if err == nil {
-			m.logger.DebugContext(ctx, "connection download finished")
-		} else if !E.IsClosedOrCanceled(err) {
-			m.logger.ErrorContext(ctx, "connection download closed: ", err)
-		} else {
-			m.logger.TraceContext(ctx, "connection download closed")
-		}
-	}
-}
-
 func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
 	ctx = adapter.WithContext(ctx, &metadata)
 	var (
@@ -227,58 +169,91 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
 		ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
 	}
 	destination := bufio.NewPacketConn(remotePacketConn)
+	m.access.Lock()
+	element := m.connections.PushBack(conn)
+	m.access.Unlock()
+	onClose = N.AppendClose(onClose, func(it error) {
+		m.access.Lock()
+		defer m.access.Unlock()
+		m.connections.Remove(element)
+	})
 	var done atomic.Bool
-	if ctx.Done() != nil {
-		onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
-	}
 	go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
 	go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
 }
 
-func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
-	_, err := bufio.CopyPacket(destination, source)
-	/*var readCounters, writeCounters []N.CountFunc
-	var cachedPackets []*N.PacketBuffer
+func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
 	originSource := source
+	originDestination := destination
+	var readCounters, writeCounters []N.CountFunc
 	for {
-		source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
-		destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
-		if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
-			packet := cachedReader.ReadCachedPacket()
-			if packet != nil {
-				cachedPackets = append(cachedPackets, packet)
-				continue
+		source, readCounters = N.UnwrapCountReader(source, readCounters)
+		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
+		if cachedSrc, isCached := source.(N.CachedReader); isCached {
+			cachedBuffer := cachedSrc.ReadCached()
+			if cachedBuffer != nil {
+				dataLen := cachedBuffer.Len()
+				_, err := destination.Write(cachedBuffer.Bytes())
+				cachedBuffer.Release()
+				if err != nil {
+					if done.Swap(true) {
+						onClose(err)
+					}
+					common.Close(originSource, originDestination)
+					if !direction {
+						m.logger.ErrorContext(ctx, "connection upload payload: ", err)
+					} else {
+						m.logger.ErrorContext(ctx, "connection download payload: ", err)
+					}
+					return
+				}
+				for _, counter := range readCounters {
+					counter(int64(dataLen))
+				}
+				for _, counter := range writeCounters {
+					counter(int64(dataLen))
+				}
 			}
+			continue
 		}
 		break
 	}
-	var handled bool
-	if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
-		natConn.SetHandler(&udpHijacker{
-			ctx:           ctx,
-			logger:        m.logger,
-			source:        natConn,
-			destination:   destination,
-			direction:     direction,
-			readCounters:  readCounters,
-			writeCounters: writeCounters,
-			done:          done,
-			onClose:       onClose,
-		})
-		handled = true
-	}
-	if cachedPackets != nil {
-		_, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
+	_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
+	if err != nil {
+		common.Close(originDestination)
+	} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
+		err = duplexDst.CloseWrite()
 		if err != nil {
-			common.Close(source, destination)
-			m.logger.ErrorContext(ctx, "packet upload payload: ", err)
-			return
+			common.Close(originSource, originDestination)
 		}
+	} else {
+		common.Close(originDestination)
 	}
-	if handled {
-		return
+	if done.Swap(true) {
+		onClose(err)
+		common.Close(originSource, originDestination)
+	}
+	if !direction {
+		if err == nil {
+			m.logger.DebugContext(ctx, "connection upload finished")
+		} else if !E.IsClosedOrCanceled(err) {
+			m.logger.ErrorContext(ctx, "connection upload closed: ", err)
+		} else {
+			m.logger.TraceContext(ctx, "connection upload closed")
+		}
+	} else {
+		if err == nil {
+			m.logger.DebugContext(ctx, "connection download finished")
+		} else if !E.IsClosedOrCanceled(err) {
+			m.logger.ErrorContext(ctx, "connection download closed: ", err)
+		} else {
+			m.logger.TraceContext(ctx, "connection download closed")
+		}
 	}
-	_, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
+}
+
+func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
+	_, err := bufio.CopyPacket(destination, source)
 	if !direction {
 		if E.IsClosedOrCanceled(err) {
 			m.logger.TraceContext(ctx, "packet upload closed")
@@ -293,58 +268,7 @@ func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.P
 		}
 	}
 	if !done.Swap(true) {
-		if onClose != nil {
-			onClose(err)
-		}
+		onClose(err)
 	}
 	common.Close(source, destination)
 }
-
-/*type udpHijacker struct {
-	ctx           context.Context
-	logger        logger.ContextLogger
-	source        io.Closer
-	destination   N.PacketWriter
-	direction     bool
-	readCounters  []N.CountFunc
-	writeCounters []N.CountFunc
-	done          *atomic.Bool
-	onClose       N.CloseHandlerFunc
-}
-
-func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
-	dataLen := buffer.Len()
-	for _, counter := range u.readCounters {
-		counter(int64(dataLen))
-	}
-	err := u.destination.WritePacket(buffer, source)
-	if err != nil {
-		common.Close(u.source, u.destination)
-		u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
-		return
-	}
-	for _, counter := range u.writeCounters {
-		counter(int64(dataLen))
-	}
-}
-
-func (u *udpHijacker) Close() error {
-	var err error
-	if !u.done.Swap(true) {
-		err = common.Close(u.source, u.destination)
-		if u.onClose != nil {
-			u.onClose(net.ErrClosed)
-		}
-	}
-	if u.direction {
-		u.logger.TraceContext(u.ctx, "packet  download closed")
-	} else {
-		u.logger.TraceContext(u.ctx, "packet upload closed")
-	}
-	return err
-}
-
-func (u *udpHijacker) Upstream() any {
-	return u.destination
-}
-*/

+ 0 - 128
route/conn_monitor.go

@@ -1,128 +0,0 @@
-package route
-
-import (
-	"context"
-	"io"
-	"reflect"
-	"sync"
-	"time"
-
-	N "github.com/sagernet/sing/common/network"
-	"github.com/sagernet/sing/common/x/list"
-)
-
-type ConnectionMonitor struct {
-	access      sync.RWMutex
-	reloadChan  chan struct{}
-	connections list.List[*monitorEntry]
-}
-
-type monitorEntry struct {
-	ctx    context.Context
-	closer io.Closer
-}
-
-func NewConnectionMonitor() *ConnectionMonitor {
-	return &ConnectionMonitor{
-		reloadChan: make(chan struct{}, 1),
-	}
-}
-
-func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc {
-	m.access.Lock()
-	defer m.access.Unlock()
-	element := m.connections.PushBack(&monitorEntry{
-		ctx:    ctx,
-		closer: closer,
-	})
-	select {
-	case <-m.reloadChan:
-		return nil
-	default:
-		select {
-		case m.reloadChan <- struct{}{}:
-		default:
-		}
-	}
-	return func(it error) {
-		m.access.Lock()
-		defer m.access.Unlock()
-		m.connections.Remove(element)
-		select {
-		case <-m.reloadChan:
-		default:
-			select {
-			case m.reloadChan <- struct{}{}:
-			default:
-			}
-		}
-	}
-}
-
-func (m *ConnectionMonitor) Start() error {
-	go m.monitor()
-	return nil
-}
-
-func (m *ConnectionMonitor) Close() error {
-	m.access.Lock()
-	defer m.access.Unlock()
-	close(m.reloadChan)
-	for element := m.connections.Front(); element != nil; element = element.Next() {
-		element.Value.closer.Close()
-	}
-	return nil
-}
-
-func (m *ConnectionMonitor) monitor() {
-	var (
-		selectCases []reflect.SelectCase
-		elements    []*list.Element[*monitorEntry]
-	)
-	rootCase := reflect.SelectCase{
-		Dir:  reflect.SelectRecv,
-		Chan: reflect.ValueOf(m.reloadChan),
-	}
-	for {
-		m.access.RLock()
-		if m.connections.Len() == 0 {
-			m.access.RUnlock()
-			if _, loaded := <-m.reloadChan; !loaded {
-				return
-			} else {
-				continue
-			}
-		}
-		if len(elements) < m.connections.Len() {
-			elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len())
-		}
-		if len(selectCases) < m.connections.Len()+1 {
-			selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1)
-		}
-		elements = elements[:0]
-		selectCases = selectCases[:1]
-		selectCases[0] = rootCase
-		for element := m.connections.Front(); element != nil; element = element.Next() {
-			elements = append(elements, element)
-			selectCases = append(selectCases, reflect.SelectCase{
-				Dir:  reflect.SelectRecv,
-				Chan: reflect.ValueOf(element.Value.ctx.Done()),
-			})
-		}
-		m.access.RUnlock()
-		selected, _, loaded := reflect.Select(selectCases)
-		if selected == 0 {
-			if !loaded {
-				return
-			} else {
-				time.Sleep(time.Second)
-				continue
-			}
-		}
-		element := elements[selected-1]
-		m.access.Lock()
-		m.connections.Remove(element)
-		m.access.Unlock()
-		element.Value.closer.Close() // maybe go close
-	}
-}

+ 0 - 43
route/conn_monitor_test.go

@@ -1,43 +0,0 @@
-package route_test
-
-import (
-	"context"
-	"sync"
-	"testing"
-	"time"
-
-	"github.com/sagernet/sing-box/route"
-
-	"github.com/stretchr/testify/require"
-)
-
-func TestMonitor(t *testing.T) {
-	t.Parallel()
-	var closer myCloser
-	closer.Add(1)
-	monitor := route.NewConnectionMonitor()
-	require.NoError(t, monitor.Start())
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	monitor.Add(ctx, &closer)
-	done := make(chan struct{})
-	go func() {
-		closer.Wait()
-		close(done)
-	}()
-	select {
-	case <-done:
-	case <-time.After(time.Second + 100*time.Millisecond):
-		t.Fatal("timeout")
-	}
-	cancel()
-	require.NoError(t, monitor.Close())
-}
-
-type myCloser struct {
-	sync.WaitGroup
-}
-
-func (c *myCloser) Close() error {
-	c.Done()
-	return nil
-}