Browse Source

Add direct io option for clash api

世界 3 years ago
parent
commit
7f9c870bba

+ 3 - 1
experimental/clashapi/server.go

@@ -38,6 +38,7 @@ type Server struct {
 	trafficManager *trafficontrol.Manager
 	urlTestHistory *urltest.HistoryStorage
 	tcpListener    net.Listener
+	directIO       bool
 	mode           string
 	storeSelected  bool
 	cacheFile      adapter.ClashCacheFile
@@ -55,6 +56,7 @@ func NewServer(router adapter.Router, logFactory log.ObservableFactory, options
 		},
 		trafficManager: trafficManager,
 		urlTestHistory: urltest.NewHistoryStorage(),
+		directIO:       options.DirectIO,
 		mode:           strings.ToLower(options.DefaultMode),
 	}
 	if server.mode == "" {
@@ -149,7 +151,7 @@ func (s *Server) HistoryStorage() *urltest.HistoryStorage {
 }
 
 func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
-	tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule)
+	tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, castMetadata(metadata), s.router, matchedRule, s.directIO)
 	return tracker, tracker
 }
 

+ 34 - 50
experimental/clashapi/trafficontrol/tracker.go

@@ -6,9 +6,8 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/experimental/trackerconn"
 	"github.com/sagernet/sing/common"
-	"github.com/sagernet/sing/common/buf"
-	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
 	"github.com/gofrs/uuid"
@@ -45,7 +44,7 @@ type trackerInfo struct {
 }
 
 type tcpTracker struct {
-	net.Conn `json:"-"`
+	N.ExtendedConn `json:"-"`
 	*trackerInfo
 	manager *Manager
 }
@@ -54,25 +53,9 @@ func (tt *tcpTracker) ID() string {
 	return tt.UUID.String()
 }
 
-func (tt *tcpTracker) Read(b []byte) (int, error) {
-	n, err := tt.Conn.Read(b)
-	upload := int64(n)
-	tt.manager.PushUploaded(upload)
-	tt.UploadTotal.Add(upload)
-	return n, err
-}
-
-func (tt *tcpTracker) Write(b []byte) (int, error) {
-	n, err := tt.Conn.Write(b)
-	download := int64(n)
-	tt.manager.PushDownloaded(download)
-	tt.DownloadTotal.Add(download)
-	return n, err
-}
-
 func (tt *tcpTracker) Close() error {
 	tt.manager.Leave(tt)
-	return tt.Conn.Close()
+	return tt.ExtendedConn.Close()
 }
 
 func (tt *tcpTracker) Leave() {
@@ -80,10 +63,18 @@ func (tt *tcpTracker) Leave() {
 }
 
 func (tt *tcpTracker) Upstream() any {
-	return tt.Conn
+	return tt.ExtendedConn
+}
+
+func (tt *tcpTracker) ReaderReplaceable() bool {
+	return true
+}
+
+func (tt *tcpTracker) WriterReplaceable() bool {
+	return true
 }
 
-func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router adapter.Router, rule adapter.Rule) *tcpTracker {
+func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router adapter.Router, rule adapter.Rule, directIO bool) *tcpTracker {
 	uuid, _ := uuid.NewV4()
 
 	var chain []string
@@ -106,17 +97,20 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad
 		next = group.Now()
 	}
 
+	upload := atomic.NewInt64(0)
+	download := atomic.NewInt64(0)
+
 	t := &tcpTracker{
-		Conn:    conn,
-		manager: manager,
+		ExtendedConn: trackerconn.New(conn, upload, download, directIO),
+		manager:      manager,
 		trackerInfo: &trackerInfo{
 			UUID:          uuid,
 			Start:         time.Now(),
 			Metadata:      metadata,
 			Chain:         common.Reverse(chain),
 			Rule:          "",
-			UploadTotal:   atomic.NewInt64(0),
-			DownloadTotal: atomic.NewInt64(0),
+			UploadTotal:   upload,
+			DownloadTotal: download,
 		},
 	}
 
@@ -140,27 +134,6 @@ func (ut *udpTracker) ID() string {
 	return ut.UUID.String()
 }
 
-func (ut *udpTracker) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
-	destination, err = ut.PacketConn.ReadPacket(buffer)
-	if err == nil {
-		upload := int64(buffer.Len())
-		ut.manager.PushUploaded(upload)
-		ut.UploadTotal.Add(upload)
-	}
-	return
-}
-
-func (ut *udpTracker) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
-	download := int64(buffer.Len())
-	err := ut.PacketConn.WritePacket(buffer, destination)
-	if err != nil {
-		return err
-	}
-	ut.manager.PushDownloaded(download)
-	ut.DownloadTotal.Add(download)
-	return nil
-}
-
 func (ut *udpTracker) Close() error {
 	ut.manager.Leave(ut)
 	return ut.PacketConn.Close()
@@ -174,6 +147,14 @@ func (ut *udpTracker) Upstream() any {
 	return ut.PacketConn
 }
 
+func (ut *udpTracker) ReaderReplaceable() bool {
+	return true
+}
+
+func (ut *udpTracker) WriterReplaceable() bool {
+	return true
+}
+
 func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, router adapter.Router, rule adapter.Rule) *udpTracker {
 	uuid, _ := uuid.NewV4()
 
@@ -197,8 +178,11 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route
 		next = group.Now()
 	}
 
+	upload := atomic.NewInt64(0)
+	download := atomic.NewInt64(0)
+
 	ut := &udpTracker{
-		PacketConn: conn,
+		PacketConn: trackerconn.NewPacket(conn, upload, download),
 		manager:    manager,
 		trackerInfo: &trackerInfo{
 			UUID:          uuid,
@@ -206,8 +190,8 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route
 			Metadata:      metadata,
 			Chain:         common.Reverse(chain),
 			Rule:          "",
-			UploadTotal:   atomic.NewInt64(0),
-			DownloadTotal: atomic.NewInt64(0),
+			UploadTotal:   upload,
+			DownloadTotal: download,
 		},
 	}
 

+ 82 - 0
experimental/trackerconn/conn.go

@@ -0,0 +1,82 @@
+package trackerconn
+
+import (
+	"io"
+	"net"
+
+	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
+	N "github.com/sagernet/sing/common/network"
+
+	"go.uber.org/atomic"
+)
+
+func New(conn net.Conn, readCounter *atomic.Int64, writeCounter *atomic.Int64, direct bool) N.ExtendedConn {
+	trackerConn := &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
+	if direct {
+		return (*DirectConn)(trackerConn)
+	} else {
+		return trackerConn
+	}
+}
+
+type Conn struct {
+	N.ExtendedConn
+	readCounter  *atomic.Int64
+	writeCounter *atomic.Int64
+}
+
+func (c *Conn) Read(p []byte) (n int, err error) {
+	n, err = c.ExtendedConn.Read(p)
+	c.readCounter.Add(int64(n))
+	return n, err
+}
+
+func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
+	err := c.ExtendedConn.ReadBuffer(buffer)
+	if err != nil {
+		return err
+	}
+	c.readCounter.Add(int64(buffer.Len()))
+	return nil
+}
+
+func (c *Conn) Write(p []byte) (n int, err error) {
+	n, err = c.ExtendedConn.Write(p)
+	c.writeCounter.Add(int64(n))
+	return n, err
+}
+
+func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
+	dataLen := int64(buffer.Len())
+	err := c.ExtendedConn.WriteBuffer(buffer)
+	if err != nil {
+		return err
+	}
+	c.writeCounter.Add(dataLen)
+	return nil
+}
+
+type DirectConn Conn
+
+func (c *DirectConn) WriteTo(w io.Writer) (n int64, err error) {
+	reader := N.UnwrapReader(c.ExtendedConn)
+	if wt, ok := reader.(io.WriterTo); ok {
+		n, err = wt.WriteTo(w)
+		c.readCounter.Add(n)
+		return
+	} else {
+		return bufio.Copy(w, (*Conn)(c))
+	}
+}
+
+func (c *DirectConn) ReadFrom(r io.Reader) (n int64, err error) {
+	writer := N.UnwrapWriter(c.ExtendedConn)
+	if rt, ok := writer.(io.ReaderFrom); ok {
+		n, err = rt.ReadFrom(r)
+		c.writeCounter.Add(n)
+		return
+	} else {
+		return bufio.Copy((*Conn)(c), r)
+	}
+}

+ 37 - 0
experimental/trackerconn/packet_conn.go

@@ -0,0 +1,37 @@
+package trackerconn
+
+import (
+	"github.com/sagernet/sing/common/buf"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	"go.uber.org/atomic"
+)
+
+type PacketConn struct {
+	N.PacketConn
+	readCounter  *atomic.Int64
+	writeCounter *atomic.Int64
+}
+
+func NewPacket(conn N.PacketConn, readCounter *atomic.Int64, writeCounter *atomic.Int64) *PacketConn {
+	return &PacketConn{conn, readCounter, writeCounter}
+}
+
+func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
+	destination, err = c.PacketConn.ReadPacket(buffer)
+	if err == nil {
+		c.readCounter.Add(int64(buffer.Len()))
+	}
+	return
+}
+
+func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	dataLen := int64(buffer.Len())
+	err := c.PacketConn.WritePacket(buffer, destination)
+	if err != nil {
+		return err
+	}
+	c.writeCounter.Add(dataLen)
+	return nil
+}

+ 1 - 0
option/clash.go

@@ -5,6 +5,7 @@ type ClashAPIOptions struct {
 	ExternalUI         string `json:"external_ui,omitempty"`
 	Secret             string `json:"secret,omitempty"`
 
+	DirectIO      bool   `json:"direct_io,omitempty"`
 	DefaultMode   string `json:"default_mode,omitempty"`
 	StoreSelected bool   `json:"store_selected,omitempty"`
 	CacheFile     string `json:"cache_file,omitempty"`