Browse Source

Add reload platform command

世界 2 years ago
parent
commit
b9b2b77814

+ 2 - 2
cmd/internal/build_libbox/main.go

@@ -86,10 +86,10 @@ func buildiOS() {
 	if !debugEnabled {
 		args = append(
 			args, "-trimpath", "-ldflags=-s -w -buildid=",
-			"-tags", "with_gvisor,with_utls,with_clash_api",
+			"-tags", "with_gvisor,with_utls,with_clash_api,with_conntrack",
 		)
 	} else {
-		args = append(args, "-tags", "with_gvisor,with_utls,with_clash_api,debug")
+		args = append(args, "-tags", "with_gvisor,with_utls,with_clash_api,with_conntrack,debug")
 	}
 
 	args = append(args, "./experimental/libbox")

+ 51 - 0
common/dialer/conntrack/conn.go

@@ -0,0 +1,51 @@
+package conntrack
+
+import (
+	"net"
+	"runtime/debug"
+
+	"github.com/sagernet/sing/common/x/list"
+)
+
+type Conn struct {
+	net.Conn
+	element *list.Element[*ConnEntry]
+}
+
+func NewConn(conn net.Conn) *Conn {
+	entry := &ConnEntry{
+		Conn:  conn,
+		Stack: debug.Stack(),
+	}
+	connAccess.Lock()
+	element := openConnection.PushBack(entry)
+	connAccess.Unlock()
+	return &Conn{
+		Conn:    conn,
+		element: element,
+	}
+}
+
+func (c *Conn) Close() error {
+	if c.element.Value != nil {
+		connAccess.Lock()
+		if c.element.Value != nil {
+			openConnection.Remove(c.element)
+			c.element.Value = nil
+		}
+		connAccess.Unlock()
+	}
+	return c.Conn.Close()
+}
+
+func (c *Conn) Upstream() any {
+	return c.Conn
+}
+
+func (c *Conn) ReaderReplaceable() bool {
+	return true
+}
+
+func (c *Conn) WriterReplaceable() bool {
+	return true
+}

+ 51 - 0
common/dialer/conntrack/packet_conn.go

@@ -0,0 +1,51 @@
+package conntrack
+
+import (
+	"net"
+	"runtime/debug"
+
+	"github.com/sagernet/sing/common/x/list"
+)
+
+type PacketConn struct {
+	net.PacketConn
+	element *list.Element[*ConnEntry]
+}
+
+func NewPacketConn(conn net.PacketConn) *PacketConn {
+	entry := &ConnEntry{
+		Conn:  conn,
+		Stack: debug.Stack(),
+	}
+	connAccess.Lock()
+	element := openConnection.PushBack(entry)
+	connAccess.Unlock()
+	return &PacketConn{
+		PacketConn: conn,
+		element:    element,
+	}
+}
+
+func (c *PacketConn) Close() error {
+	if c.element.Value != nil {
+		connAccess.Lock()
+		if c.element.Value != nil {
+			openConnection.Remove(c.element)
+			c.element.Value = nil
+		}
+		connAccess.Unlock()
+	}
+	return c.PacketConn.Close()
+}
+
+func (c *PacketConn) Upstream() any {
+	return c.PacketConn
+}
+
+func (c *PacketConn) ReaderReplaceable() bool {
+	return true
+}
+
+func (c *PacketConn) WriterReplaceable() bool {
+	return true
+}

+ 43 - 0
common/dialer/conntrack/track.go

@@ -0,0 +1,43 @@
+package conntrack
+
+import (
+	"io"
+	"sync"
+
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/x/list"
+)
+
+var (
+	connAccess     sync.RWMutex
+	openConnection list.List[*ConnEntry]
+)
+
+type ConnEntry struct {
+	Conn  io.Closer
+	Stack []byte
+}
+
+func Count() int {
+	return openConnection.Len()
+}
+
+func List() []*ConnEntry {
+	connAccess.RLock()
+	defer connAccess.RUnlock()
+	connList := make([]*ConnEntry, 0, openConnection.Len())
+	for element := openConnection.Front(); element != nil; element = element.Next() {
+		connList = append(connList, element.Value)
+	}
+	return connList
+}
+
+func Close() {
+	connAccess.Lock()
+	defer connAccess.Unlock()
+	for element := openConnection.Front(); element != nil; element = element.Next() {
+		common.Close(element.Value.Conn)
+		element.Value = nil
+	}
+	openConnection = list.List[*ConnEntry]{}
+}

+ 5 - 0
common/dialer/conntrack/track_disable.go

@@ -0,0 +1,5 @@
+//go:build !with_conntrack
+
+package conntrack
+
+const Enabled = false

+ 5 - 0
common/dialer/conntrack/track_enable.go

@@ -0,0 +1,5 @@
+//go:build with_conntrack
+
+package conntrack
+
+const Enabled = true

+ 19 - 4
common/dialer/default.go

@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer/conntrack"
 	"github.com/sagernet/sing-box/common/warning"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/option"
@@ -159,16 +160,30 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address
 		}
 	}
 	if !address.IsIPv6() {
-		return DialSlowContext(&d.dialer4, ctx, network, address)
+		return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
 	} else {
-		return DialSlowContext(&d.dialer6, ctx, network, address)
+		return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
 	}
 }
 
 func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
 	if !destination.IsIPv6() {
-		return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4)
+		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
 	} else {
-		return d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)
+		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
 	}
 }
+
+func trackConn(conn net.Conn, err error) (net.Conn, error) {
+	if !conntrack.Enabled || err != nil {
+		return conn, err
+	}
+	return conntrack.NewConn(conn), nil
+}
+
+func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
+	if !conntrack.Enabled || err != nil {
+		return conn, err
+	}
+	return conntrack.NewPacketConn(conn), nil
+}

+ 2 - 0
experimental/libbox/command.go

@@ -5,4 +5,6 @@ package libbox
 const (
 	CommandLog int32 = iota
 	CommandStatus
+	CommandServiceReload
+	CommandCloseConnections
 )

+ 14 - 10
experimental/libbox/command_client.go

@@ -12,10 +12,10 @@ import (
 )
 
 type CommandClient struct {
-	sockPath string
-	handler  CommandClientHandler
-	conn     net.Conn
-	options  CommandClientOptions
+	sharedDirectory string
+	handler         CommandClientHandler
+	conn            net.Conn
+	options         CommandClientOptions
 }
 
 type CommandClientOptions struct {
@@ -32,17 +32,21 @@ type CommandClientHandler interface {
 
 func NewCommandClient(sharedDirectory string, handler CommandClientHandler, options *CommandClientOptions) *CommandClient {
 	return &CommandClient{
-		sockPath: filepath.Join(sharedDirectory, "command.sock"),
-		handler:  handler,
-		options:  common.PtrValueOrDefault(options),
+		sharedDirectory: sharedDirectory,
+		handler:         handler,
+		options:         common.PtrValueOrDefault(options),
 	}
 }
 
-func (c *CommandClient) Connect() error {
-	conn, err := net.DialUnix("unix", nil, &net.UnixAddr{
-		Name: c.sockPath,
+func clientConnect(sharedDirectory string) (net.Conn, error) {
+	return net.DialUnix("unix", nil, &net.UnixAddr{
+		Name: filepath.Join(sharedDirectory, "command.sock"),
 		Net:  "unix",
 	})
+}
+
+func (c *CommandClient) Connect() error {
+	conn, err := clientConnect(c.sharedDirectory)
 	if err != nil {
 		return err
 	}

+ 30 - 0
experimental/libbox/command_conntrack.go

@@ -0,0 +1,30 @@
+//go:build darwin
+
+package libbox
+
+import (
+	"encoding/binary"
+	"net"
+	runtimeDebug "runtime/debug"
+	"time"
+
+	"github.com/sagernet/sing-box/common/dialer/conntrack"
+)
+
+func ClientCloseConnections(sharedDirectory string) error {
+	conn, err := clientConnect(sharedDirectory)
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+	return binary.Write(conn, binary.BigEndian, uint8(CommandCloseConnections))
+}
+
+func (s *CommandServer) handleCloseConnections(conn net.Conn) error {
+	conntrack.Close()
+	go func() {
+		time.Sleep(time.Second)
+		runtimeDebug.FreeOSMemory()
+	}()
+	return nil
+}

+ 48 - 0
experimental/libbox/command_reload.go

@@ -0,0 +1,48 @@
+//go:build darwin
+
+package libbox
+
+import (
+	"encoding/binary"
+	"net"
+
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/rw"
+)
+
+func ClientServiceReload(sharedDirectory string) error {
+	conn, err := clientConnect(sharedDirectory)
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+	err = binary.Write(conn, binary.BigEndian, uint8(CommandServiceReload))
+	if err != nil {
+		return err
+	}
+	var hasError bool
+	err = binary.Read(conn, binary.BigEndian, &hasError)
+	if err != nil {
+		return err
+	}
+	if hasError {
+		errorMessage, err := rw.ReadVString(conn)
+		if err != nil {
+			return err
+		}
+		return E.New(errorMessage)
+	}
+	return nil
+}
+
+func (s *CommandServer) handleServiceReload(conn net.Conn) error {
+	rErr := s.handler.ServiceReload()
+	err := binary.Write(conn, binary.BigEndian, rErr != nil)
+	if err != nil {
+		return err
+	}
+	if rErr != nil {
+		return rw.WriteVString(conn, rErr.Error())
+	}
+	return nil
+}

+ 11 - 1
experimental/libbox/command_server.go

@@ -18,6 +18,7 @@ import (
 type CommandServer struct {
 	sockPath string
 	listener net.Listener
+	handler  CommandServerHandler
 
 	access     sync.Mutex
 	savedLines *list.List[string]
@@ -25,9 +26,14 @@ type CommandServer struct {
 	observer   *observable.Observer[string]
 }
 
-func NewCommandServer(sharedDirectory string) *CommandServer {
+type CommandServerHandler interface {
+	ServiceReload() error
+}
+
+func NewCommandServer(sharedDirectory string, handler CommandServerHandler) *CommandServer {
 	server := &CommandServer{
 		sockPath:   filepath.Join(sharedDirectory, "command.sock"),
+		handler:    handler,
 		savedLines: new(list.List[string]),
 		subscriber: observable.NewSubscriber[string](128),
 	}
@@ -79,6 +85,10 @@ func (s *CommandServer) handleConnection(conn net.Conn) error {
 		return s.handleLogConn(conn)
 	case CommandStatus:
 		return s.handleStatusConn(conn)
+	case CommandServiceReload:
+		return s.handleServiceReload(conn)
+	case CommandCloseConnections:
+		return s.handleCloseConnections(conn)
 	default:
 		return E.New("unknown command: ", command)
 	}

+ 5 - 2
experimental/libbox/command_status.go

@@ -8,12 +8,14 @@ import (
 	"runtime"
 	"time"
 
+	"github.com/sagernet/sing-box/common/dialer/conntrack"
 	E "github.com/sagernet/sing/common/exceptions"
 )
 
 type StatusMessage struct {
-	Memory     int64
-	Goroutines int32
+	Memory      int64
+	Goroutines  int32
+	Connections int32
 }
 
 func readStatus() StatusMessage {
@@ -22,6 +24,7 @@ func readStatus() StatusMessage {
 	var message StatusMessage
 	message.Memory = int64(memStats.Sys - memStats.HeapReleased)
 	message.Goroutines = int32(runtime.NumGoroutine())
+	message.Connections = int32(conntrack.Count())
 	return message
 }