浏览代码

Fix clash tracker

世界 3 年之前
父节点
当前提交
29d08e63b5

+ 2 - 4
experimental/clashapi/server.go

@@ -31,9 +31,7 @@ import (
 )
 )
 
 
 func init() {
 func init() {
-	experimental.RegisterClashServerConstructor(func(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
-		return NewServer(router, logFactory, options)
-	})
+	experimental.RegisterClashServerConstructor(NewServer)
 }
 }
 
 
 var _ adapter.ClashServer = (*Server)(nil)
 var _ adapter.ClashServer = (*Server)(nil)
@@ -51,7 +49,7 @@ type Server struct {
 	cacheFile      adapter.ClashCacheFile
 	cacheFile      adapter.ClashCacheFile
 }
 }
 
 
-func NewServer(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (*Server, error) {
+func NewServer(router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
 	trafficManager := trafficontrol.NewManager()
 	trafficManager := trafficontrol.NewManager()
 	chiRouter := chi.NewRouter()
 	chiRouter := chi.NewRouter()
 	server := &Server{
 	server := &Server{

+ 16 - 4
experimental/clashapi/trafficontrol/tracker.go

@@ -101,8 +101,14 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad
 	download := atomic.NewInt64(0)
 	download := atomic.NewInt64(0)
 
 
 	t := &tcpTracker{
 	t := &tcpTracker{
-		ExtendedConn: trackerconn.New(conn, upload, download, directIO),
-		manager:      manager,
+		ExtendedConn: trackerconn.NewHook(conn, func(n int64) {
+			upload.Add(n)
+			manager.PushUploaded(n)
+		}, func(n int64) {
+			download.Add(n)
+			manager.PushDownloaded(n)
+		}, directIO),
+		manager: manager,
 		trackerInfo: &trackerInfo{
 		trackerInfo: &trackerInfo{
 			UUID:          uuid,
 			UUID:          uuid,
 			Start:         time.Now(),
 			Start:         time.Now(),
@@ -182,8 +188,14 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route
 	download := atomic.NewInt64(0)
 	download := atomic.NewInt64(0)
 
 
 	ut := &udpTracker{
 	ut := &udpTracker{
-		PacketConn: trackerconn.NewPacket(conn, upload, download),
-		manager:    manager,
+		PacketConn: trackerconn.NewHookPacket(conn, func(n int64) {
+			upload.Add(n)
+			manager.PushUploaded(n)
+		}, func(n int64) {
+			download.Add(n)
+			manager.PushDownloaded(n)
+		}),
+		manager: manager,
 		trackerInfo: &trackerInfo{
 		trackerInfo: &trackerInfo{
 			UUID:          uuid,
 			UUID:          uuid,
 			Start:         time.Now(),
 			Start:         time.Now(),

+ 74 - 0
experimental/trackerconn/conn.go

@@ -20,6 +20,15 @@ func New(conn net.Conn, readCounter *atomic.Int64, writeCounter *atomic.Int64, d
 	}
 	}
 }
 }
 
 
+func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64), direct bool) N.ExtendedConn {
+	trackerConn := &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
+	if direct {
+		return (*DirectHookConn)(trackerConn)
+	} else {
+		return trackerConn
+	}
+}
+
 type Conn struct {
 type Conn struct {
 	N.ExtendedConn
 	N.ExtendedConn
 	readCounter  *atomic.Int64
 	readCounter  *atomic.Int64
@@ -61,6 +70,47 @@ func (c *Conn) Upstream() any {
 	return c.ExtendedConn
 	return c.ExtendedConn
 }
 }
 
 
+type HookConn struct {
+	N.ExtendedConn
+	readCounter  func(n int64)
+	writeCounter func(n int64)
+}
+
+func (c *HookConn) Read(p []byte) (n int, err error) {
+	n, err = c.ExtendedConn.Read(p)
+	c.readCounter(int64(n))
+	return n, err
+}
+
+func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error {
+	err := c.ExtendedConn.ReadBuffer(buffer)
+	if err != nil {
+		return err
+	}
+	c.readCounter(int64(buffer.Len()))
+	return nil
+}
+
+func (c *HookConn) Write(p []byte) (n int, err error) {
+	n, err = c.ExtendedConn.Write(p)
+	c.writeCounter(int64(n))
+	return n, err
+}
+
+func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error {
+	dataLen := int64(buffer.Len())
+	err := c.ExtendedConn.WriteBuffer(buffer)
+	if err != nil {
+		return err
+	}
+	c.writeCounter(dataLen)
+	return nil
+}
+
+func (c *HookConn) Upstream() any {
+	return c.ExtendedConn
+}
+
 type DirectConn Conn
 type DirectConn Conn
 
 
 func (c *DirectConn) WriteTo(w io.Writer) (n int64, err error) {
 func (c *DirectConn) WriteTo(w io.Writer) (n int64, err error) {
@@ -84,3 +134,27 @@ func (c *DirectConn) ReadFrom(r io.Reader) (n int64, err error) {
 		return bufio.Copy((*Conn)(c), r)
 		return bufio.Copy((*Conn)(c), r)
 	}
 	}
 }
 }
+
+type DirectHookConn HookConn
+
+func (c *DirectHookConn) WriteTo(w io.Writer) (n int64, err error) {
+	reader := N.UnwrapReader(c.ExtendedConn)
+	if wt, ok := reader.(io.WriterTo); ok {
+		n, err = wt.WriteTo(w)
+		c.readCounter(n)
+		return
+	} else {
+		return bufio.Copy(w, (*HookConn)(c))
+	}
+}
+
+func (c *DirectHookConn) ReadFrom(r io.Reader) (n int64, err error) {
+	writer := N.UnwrapWriter(c.ExtendedConn)
+	if rt, ok := writer.(io.ReaderFrom); ok {
+		n, err = rt.ReadFrom(r)
+		c.writeCounter(n)
+		return
+	} else {
+		return bufio.Copy((*HookConn)(c), r)
+	}
+}

+ 36 - 4
experimental/trackerconn/packet_conn.go

@@ -8,16 +8,20 @@ import (
 	"go.uber.org/atomic"
 	"go.uber.org/atomic"
 )
 )
 
 
+func NewPacket(conn N.PacketConn, readCounter *atomic.Int64, writeCounter *atomic.Int64) *PacketConn {
+	return &PacketConn{conn, readCounter, writeCounter}
+}
+
+func NewHookPacket(conn N.PacketConn, readCounter func(n int64), writeCounter func(n int64)) *HookPacketConn {
+	return &HookPacketConn{conn, readCounter, writeCounter}
+}
+
 type PacketConn struct {
 type PacketConn struct {
 	N.PacketConn
 	N.PacketConn
 	readCounter  *atomic.Int64
 	readCounter  *atomic.Int64
 	writeCounter *atomic.Int64
 	writeCounter *atomic.Int64
 }
 }
 
 
-func NewPacket(conn N.PacketConn, readCounter *atomic.Int64, writeCounter *atomic.Int64) *PacketConn {
-	return &PacketConn{conn, readCounter, writeCounter}
-}
-
 func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
 func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
 	destination, err = c.PacketConn.ReadPacket(buffer)
 	destination, err = c.PacketConn.ReadPacket(buffer)
 	if err == nil {
 	if err == nil {
@@ -39,3 +43,31 @@ func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) er
 func (c *PacketConn) Upstream() any {
 func (c *PacketConn) Upstream() any {
 	return c.PacketConn
 	return c.PacketConn
 }
 }
+
+type HookPacketConn struct {
+	N.PacketConn
+	readCounter  func(n int64)
+	writeCounter func(n int64)
+}
+
+func (c *HookPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
+	destination, err = c.PacketConn.ReadPacket(buffer)
+	if err == nil {
+		c.readCounter(int64(buffer.Len()))
+	}
+	return
+}
+
+func (c *HookPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
+	dataLen := int64(buffer.Len())
+	err := c.PacketConn.WritePacket(buffer, destination)
+	if err != nil {
+		return err
+	}
+	c.writeCounter(dataLen)
+	return nil
+}
+
+func (c *HookPacketConn) Upstream() any {
+	return c.PacketConn
+}