Browse Source

Add interrupt support for outbound groups

世界 2 years ago
parent
commit
c320be75a7

+ 75 - 0
common/interrupt/conn.go

@@ -0,0 +1,75 @@
+package interrupt
+
+import (
+	"net"
+
+	"github.com/sagernet/sing/common/x/list"
+)
+
+/*type GroupedConn interface {
+	MarkAsInternal()
+}
+
+func MarkAsInternal(conn any) {
+	if groupedConn, isGroupConn := common.Cast[GroupedConn](conn); isGroupConn {
+		groupedConn.MarkAsInternal()
+	}
+}*/
+
+type Conn struct {
+	net.Conn
+	group   *Group
+	element *list.Element[*groupConnItem]
+}
+
+/*func (c *Conn) MarkAsInternal() {
+	c.element.Value.internal = true
+}*/
+
+func (c *Conn) Close() error {
+	c.group.access.Lock()
+	defer c.group.access.Unlock()
+	c.group.connections.Remove(c.element)
+	return c.Conn.Close()
+}
+
+func (c *Conn) ReaderReplaceable() bool {
+	return true
+}
+
+func (c *Conn) WriterReplaceable() bool {
+	return true
+}
+
+func (c *Conn) Upstream() any {
+	return c.Conn
+}
+
+type PacketConn struct {
+	net.PacketConn
+	group   *Group
+	element *list.Element[*groupConnItem]
+}
+
+/*func (c *PacketConn) MarkAsInternal() {
+	c.element.Value.internal = true
+}*/
+
+func (c *PacketConn) Close() error {
+	c.group.access.Lock()
+	defer c.group.access.Unlock()
+	c.group.connections.Remove(c.element)
+	return c.PacketConn.Close()
+}
+
+func (c *PacketConn) ReaderReplaceable() bool {
+	return true
+}
+
+func (c *PacketConn) WriterReplaceable() bool {
+	return true
+}
+
+func (c *PacketConn) Upstream() any {
+	return c.PacketConn
+}

+ 13 - 0
common/interrupt/context.go

@@ -0,0 +1,13 @@
+package interrupt
+
+import "context"
+
+type contextKeyIsExternalConnection struct{}
+
+func ContextWithIsExternalConnection(ctx context.Context) context.Context {
+	return context.WithValue(ctx, contextKeyIsExternalConnection{}, true)
+}
+
+func IsExternalConnectionFromContext(ctx context.Context) bool {
+	return ctx.Value(contextKeyIsExternalConnection{}) != nil
+}

+ 52 - 0
common/interrupt/group.go

@@ -0,0 +1,52 @@
+package interrupt
+
+import (
+	"io"
+	"net"
+	"sync"
+
+	"github.com/sagernet/sing/common/x/list"
+)
+
+type Group struct {
+	access      sync.Mutex
+	connections list.List[*groupConnItem]
+}
+
+type groupConnItem struct {
+	conn       io.Closer
+	isExternal bool
+}
+
+func NewGroup() *Group {
+	return &Group{}
+}
+
+func (g *Group) NewConn(conn net.Conn, isExternal bool) net.Conn {
+	g.access.Lock()
+	defer g.access.Unlock()
+	item := g.connections.PushBack(&groupConnItem{conn, isExternal})
+	return &Conn{Conn: conn, group: g, element: item}
+}
+
+func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool) net.PacketConn {
+	g.access.Lock()
+	defer g.access.Unlock()
+	item := g.connections.PushBack(&groupConnItem{conn, isExternal})
+	return &PacketConn{PacketConn: conn, group: g, element: item}
+}
+
+func (g *Group) Interrupt(interruptExternalConnections bool) {
+	g.access.Lock()
+	defer g.access.Unlock()
+	var toDelete []*list.Element[*groupConnItem]
+	for element := g.connections.Front(); element != nil; element = element.Next() {
+		if !element.Value.isExternal || interruptExternalConnections {
+			element.Value.conn.Close()
+			toDelete = append(toDelete, element)
+		}
+	}
+	for _, element := range toDelete {
+		g.connections.Remove(element)
+	}
+}

+ 9 - 2
docs/configuration/outbound/selector.md

@@ -10,7 +10,8 @@
     "proxy-b",
     "proxy-c"
   ],
-  "default": "proxy-c"
+  "default": "proxy-c",
+  "interrupt_exist_connections": false
 }
 ```
 
@@ -28,4 +29,10 @@ List of outbound tags to select.
 
 #### default
 
-The default outbound tag. The first outbound will be used if empty.
+The default outbound tag. The first outbound will be used if empty.
+
+#### interrupt_exist_connections
+
+Interrupt existing connections when the selected outbound has changed.
+
+Only inbound connections are affected by this setting, internal connections will always be interrupted.

+ 8 - 1
docs/configuration/outbound/selector.zh.md

@@ -10,7 +10,8 @@
     "proxy-b",
     "proxy-c"
   ],
-  "default": "proxy-c"
+  "default": "proxy-c",
+  "interrupt_exist_connections": false
 }
 ```
 
@@ -29,3 +30,9 @@
 #### default
 
 默认的出站标签。默认使用第一个出站。
+
+#### interrupt_exist_connections
+
+当选定的出站发生更改时,中断现有连接。
+
+仅入站连接受此设置影响,内部连接将始终被中断。

+ 8 - 1
docs/configuration/outbound/urltest.md

@@ -12,7 +12,8 @@
   ],
   "url": "https://www.gstatic.com/generate_204",
   "interval": "1m",
-  "tolerance": 50
+  "tolerance": 50,
+  "interrupt_exist_connections": false
 }
 ```
 
@@ -35,3 +36,9 @@ The test interval. `1m` will be used if empty.
 #### tolerance
 
 The test tolerance in milliseconds. `50` will be used if empty.
+
+#### interrupt_exist_connections
+
+Interrupt existing connections when the selected outbound has changed.
+
+Only inbound connections are affected by this setting, internal connections will always be interrupted.

+ 8 - 1
docs/configuration/outbound/urltest.zh.md

@@ -12,7 +12,8 @@
   ],
   "url": "https://www.gstatic.com/generate_204",
   "interval": "1m",
-  "tolerance": 50
+  "tolerance": 50,
+  "interrupt_exist_connections": false
 }
 ```
 
@@ -35,3 +36,9 @@
 #### tolerance
 
 以毫秒为单位的测试容差。 默认使用 `50`。
+
+#### interrupt_exist_connections
+
+当选定的出站发生更改时,中断现有连接。
+
+仅入站连接受此设置影响,内部连接将始终被中断。

+ 8 - 6
option/clash.go

@@ -17,13 +17,15 @@ type ClashAPIOptions struct {
 }
 
 type SelectorOutboundOptions struct {
-	Outbounds []string `json:"outbounds"`
-	Default   string   `json:"default,omitempty"`
+	Outbounds                 []string `json:"outbounds"`
+	Default                   string   `json:"default,omitempty"`
+	InterruptExistConnections bool     `json:"interrupt_exist_connections,omitempty"`
 }
 
 type URLTestOutboundOptions struct {
-	Outbounds []string `json:"outbounds"`
-	URL       string   `json:"url,omitempty"`
-	Interval  Duration `json:"interval,omitempty"`
-	Tolerance uint16   `json:"tolerance,omitempty"`
+	Outbounds                 []string `json:"outbounds"`
+	URL                       string   `json:"url,omitempty"`
+	Interval                  Duration `json:"interval,omitempty"`
+	Tolerance                 uint16   `json:"tolerance,omitempty"`
+	InterruptExistConnections bool     `json:"interrupt_exist_connections,omitempty"`
 }

+ 28 - 9
outbound/selector.go

@@ -5,6 +5,7 @@ import (
 	"net"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/interrupt"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
@@ -20,10 +21,12 @@ var (
 
 type Selector struct {
 	myOutboundAdapter
-	tags       []string
-	defaultTag string
-	outbounds  map[string]adapter.Outbound
-	selected   adapter.Outbound
+	tags                         []string
+	defaultTag                   string
+	outbounds                    map[string]adapter.Outbound
+	selected                     adapter.Outbound
+	interruptGroup               *interrupt.Group
+	interruptExternalConnections bool
 }
 
 func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (*Selector, error) {
@@ -35,9 +38,11 @@ func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, op
 			tag:          tag,
 			dependencies: options.Outbounds,
 		},
-		tags:       options.Outbounds,
-		defaultTag: options.Default,
-		outbounds:  make(map[string]adapter.Outbound),
+		tags:                         options.Outbounds,
+		defaultTag:                   options.Default,
+		outbounds:                    make(map[string]adapter.Outbound),
+		interruptGroup:               interrupt.NewGroup(),
+		interruptExternalConnections: options.InterruptExistConnections,
 	}
 	if len(outbound.tags) == 0 {
 		return nil, E.New("missing tags")
@@ -100,6 +105,9 @@ func (s *Selector) SelectOutbound(tag string) bool {
 	if !loaded {
 		return false
 	}
+	if s.selected == detour {
+		return true
+	}
 	s.selected = detour
 	if s.tag != "" {
 		if clashServer := s.router.ClashServer(); clashServer != nil && clashServer.StoreSelected() {
@@ -109,22 +117,33 @@ func (s *Selector) SelectOutbound(tag string) bool {
 			}
 		}
 	}
+	s.interruptGroup.Interrupt(s.interruptExternalConnections)
 	return true
 }
 
 func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
-	return s.selected.DialContext(ctx, network, destination)
+	conn, err := s.selected.DialContext(ctx, network, destination)
+	if err != nil {
+		return nil, err
+	}
+	return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
 }
 
 func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
-	return s.selected.ListenPacket(ctx, destination)
+	conn, err := s.selected.ListenPacket(ctx, destination)
+	if err != nil {
+		return nil, err
+	}
+	return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
 }
 
 func (s *Selector) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	ctx = interrupt.ContextWithIsExternalConnection(ctx)
 	return s.selected.NewConnection(ctx, conn, metadata)
 }
 
 func (s *Selector) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	ctx = interrupt.ContextWithIsExternalConnection(ctx)
 	return s.selected.NewPacketConnection(ctx, conn, metadata)
 }
 

+ 73 - 35
outbound/urltest.go

@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/interrupt"
 	"github.com/sagernet/sing-box/common/urltest"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
@@ -30,12 +31,13 @@ var (
 
 type URLTest struct {
 	myOutboundAdapter
-	ctx       context.Context
-	tags      []string
-	link      string
-	interval  time.Duration
-	tolerance uint16
-	group     *URLTestGroup
+	ctx                          context.Context
+	tags                         []string
+	link                         string
+	interval                     time.Duration
+	tolerance                    uint16
+	group                        *URLTestGroup
+	interruptExternalConnections bool
 }
 
 func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (*URLTest, error) {
@@ -47,11 +49,12 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
 			tag:          tag,
 			dependencies: options.Outbounds,
 		},
-		ctx:       ctx,
-		tags:      options.Outbounds,
-		link:      options.URL,
-		interval:  time.Duration(options.Interval),
-		tolerance: options.Tolerance,
+		ctx:                          ctx,
+		tags:                         options.Outbounds,
+		link:                         options.URL,
+		interval:                     time.Duration(options.Interval),
+		tolerance:                    options.Tolerance,
+		interruptExternalConnections: options.InterruptExistConnections,
 	}
 	if len(outbound.tags) == 0 {
 		return nil, E.New("missing tags")
@@ -75,7 +78,7 @@ func (s *URLTest) Start() error {
 		}
 		outbounds = append(outbounds, detour)
 	}
-	s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance)
+	s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance, s.interruptExternalConnections)
 	return nil
 }
 
@@ -111,7 +114,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M
 	outbound := s.group.Select(network)
 	conn, err := outbound.DialContext(ctx, network, destination)
 	if err == nil {
-		return conn, nil
+		return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
 	}
 	s.logger.ErrorContext(ctx, err)
 	s.group.history.DeleteURLTestHistory(outbound.Tag())
@@ -123,7 +126,7 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
 	outbound := s.group.Select(N.NetworkUDP)
 	conn, err := outbound.ListenPacket(ctx, destination)
 	if err == nil {
-		return conn, nil
+		return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil
 	}
 	s.logger.ErrorContext(ctx, err)
 	s.group.history.DeleteURLTestHistory(outbound.Tag())
@@ -131,10 +134,12 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne
 }
 
 func (s *URLTest) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	ctx = interrupt.ContextWithIsExternalConnection(ctx)
 	return NewConnection(ctx, s, conn, metadata)
 }
 
 func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	ctx = interrupt.ContextWithIsExternalConnection(ctx)
 	return NewPacketConnection(ctx, s, conn, metadata)
 }
 
@@ -144,23 +149,36 @@ func (s *URLTest) InterfaceUpdated() {
 }
 
 type URLTestGroup struct {
-	ctx          context.Context
-	router       adapter.Router
-	logger       log.Logger
-	outbounds    []adapter.Outbound
-	link         string
-	interval     time.Duration
-	tolerance    uint16
-	history      *urltest.HistoryStorage
-	checking     atomic.Bool
-	pauseManager pause.Manager
+	ctx                          context.Context
+	router                       adapter.Router
+	logger                       log.Logger
+	outbounds                    []adapter.Outbound
+	link                         string
+	interval                     time.Duration
+	tolerance                    uint16
+	history                      *urltest.HistoryStorage
+	checking                     atomic.Bool
+	pauseManager                 pause.Manager
+	selectedOutboundTCP          adapter.Outbound
+	selectedOutboundUDP          adapter.Outbound
+	interruptGroup               *interrupt.Group
+	interruptExternalConnections bool
 
 	access sync.Mutex
 	ticker *time.Ticker
 	close  chan struct{}
 }
 
-func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logger, outbounds []adapter.Outbound, link string, interval time.Duration, tolerance uint16) *URLTestGroup {
+func NewURLTestGroup(
+	ctx context.Context,
+	router adapter.Router,
+	logger log.Logger,
+	outbounds []adapter.Outbound,
+	link string,
+	interval time.Duration,
+	tolerance uint16,
+	interruptExternalConnections bool,
+) *URLTestGroup {
 	if interval == 0 {
 		interval = C.DefaultURLTestInterval
 	}
@@ -175,16 +193,18 @@ func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logg
 		history = urltest.NewHistoryStorage()
 	}
 	return &URLTestGroup{
-		ctx:          ctx,
-		router:       router,
-		logger:       logger,
-		outbounds:    outbounds,
-		link:         link,
-		interval:     interval,
-		tolerance:    tolerance,
-		history:      history,
-		close:        make(chan struct{}),
-		pauseManager: pause.ManagerFromContext(ctx),
+		ctx:                          ctx,
+		router:                       router,
+		logger:                       logger,
+		outbounds:                    outbounds,
+		link:                         link,
+		interval:                     interval,
+		tolerance:                    tolerance,
+		history:                      history,
+		close:                        make(chan struct{}),
+		pauseManager:                 pause.ManagerFromContext(ctx),
+		interruptGroup:               interrupt.NewGroup(),
+		interruptExternalConnections: interruptExternalConnections,
 	}
 }
 
@@ -329,5 +349,23 @@ func (g *URLTestGroup) urlTest(ctx context.Context, link string, force bool) (ma
 		})
 	}
 	b.Wait()
+	g.performUpdateCheck()
 	return result, nil
 }
+
+func (g *URLTestGroup) performUpdateCheck() {
+	outbound := g.Select(N.NetworkTCP)
+	var updated bool
+	if outbound != g.selectedOutboundTCP {
+		g.selectedOutboundTCP = outbound
+		updated = true
+	}
+	outbound = g.Select(N.NetworkUDP)
+	if outbound != g.selectedOutboundUDP {
+		g.selectedOutboundUDP = outbound
+		updated = true
+	}
+	if updated {
+		g.interruptGroup.Interrupt(g.interruptExternalConnections)
+	}
+}