Ver código fonte

URLTest improvements

世界 2 anos atrás
pai
commit
b491c350ae
1 arquivos alterados com 41 adições e 11 exclusões
  1. 41 11
      outbound/urltest.go

+ 41 - 11
outbound/urltest.go

@@ -13,6 +13,7 @@ import (
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/atomic"
 	"github.com/sagernet/sing/common/batch"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
@@ -20,8 +21,9 @@ import (
 )
 
 var (
-	_ adapter.Outbound      = (*URLTest)(nil)
-	_ adapter.OutboundGroup = (*URLTest)(nil)
+	_ adapter.Outbound                = (*URLTest)(nil)
+	_ adapter.OutboundGroup           = (*URLTest)(nil)
+	_ adapter.InterfaceUpdateListener = (*URLTest)(nil)
 )
 
 type URLTest struct {
@@ -71,7 +73,8 @@ func (s *URLTest) Start() error {
 		outbounds = append(outbounds, detour)
 	}
 	s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance)
-	return s.group.Start()
+	go s.group.CheckOutbounds(false)
+	return nil
 }
 
 func (s *URLTest) Close() error {
@@ -93,6 +96,7 @@ func (s *URLTest) URLTest(ctx context.Context, link string) (map[string]uint16,
 }
 
 func (s *URLTest) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
+	s.group.Start()
 	outbound := s.group.Select(network)
 	conn, err := outbound.DialContext(ctx, network, destination)
 	if err == nil {
@@ -104,6 +108,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M
 }
 
 func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
+	s.group.Start()
 	outbound := s.group.Select(N.NetworkUDP)
 	conn, err := outbound.ListenPacket(ctx, destination)
 	if err == nil {
@@ -122,6 +127,11 @@ func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, me
 	return NewPacketConnection(ctx, s, conn, metadata)
 }
 
+func (s *URLTest) InterfaceUpdated() error {
+	go s.group.CheckOutbounds(true)
+	return nil
+}
+
 type URLTestGroup struct {
 	ctx       context.Context
 	router    adapter.Router
@@ -131,7 +141,9 @@ type URLTestGroup struct {
 	interval  time.Duration
 	tolerance uint16
 	history   *urltest.HistoryStorage
+	checking  atomic.Bool
 
+	access sync.Mutex
 	ticker *time.Ticker
 	close  chan struct{}
 }
@@ -162,13 +174,23 @@ func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logg
 	}
 }
 
-func (g *URLTestGroup) Start() error {
+func (g *URLTestGroup) Start() {
+	if g.ticker != nil {
+		return
+	}
+	g.access.Lock()
+	defer g.access.Unlock()
+	if g.ticker != nil {
+		return
+	}
 	g.ticker = time.NewTicker(g.interval)
 	go g.loopCheck()
-	return nil
 }
 
 func (g *URLTestGroup) Close() error {
+	if g.ticker == nil {
+		return nil
+	}
 	g.ticker.Stop()
 	close(g.close)
 	return nil
@@ -228,25 +250,33 @@ func (g *URLTestGroup) Fallback(used adapter.Outbound) []adapter.Outbound {
 }
 
 func (g *URLTestGroup) loopCheck() {
-	go g.checkOutbounds()
+	go g.CheckOutbounds(true)
 	for {
 		select {
 		case <-g.close:
 			return
 		case <-g.ticker.C:
-			g.checkOutbounds()
+			g.CheckOutbounds(false)
 		}
 	}
 }
 
-func (g *URLTestGroup) checkOutbounds() {
-	_, _ = g.URLTest(g.ctx, g.link)
+func (g *URLTestGroup) CheckOutbounds(force bool) {
+	_, _ = g.urlTest(g.ctx, g.link, force)
 }
 
 func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uint16, error) {
+	return g.urlTest(ctx, link, false)
+}
+
+func (g *URLTestGroup) urlTest(ctx context.Context, link string, force bool) (map[string]uint16, error) {
+	result := make(map[string]uint16)
+	if g.checking.Swap(true) {
+		return result, nil
+	}
+	defer g.checking.Store(false)
 	b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10))
 	checked := make(map[string]bool)
-	result := make(map[string]uint16)
 	var resultAccess sync.Mutex
 	for _, detour := range g.outbounds {
 		tag := detour.Tag()
@@ -255,7 +285,7 @@ func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uin
 			continue
 		}
 		history := g.history.LoadURLTestHistory(realTag)
-		if history != nil && time.Now().Sub(history.Time) < g.interval {
+		if !force && history != nil && time.Now().Sub(history.Time) < g.interval {
 			continue
 		}
 		checked[realTag] = true