Bläddra i källkod

Improve direct copy

世界 2 år sedan
förälder
incheckning
c287731df9

+ 7 - 7
experimental/clashapi/trafficontrol/tracker.go

@@ -7,9 +7,9 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/experimental/trackerconn"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/atomic"
+	"github.com/sagernet/sing/common/bufio"
 	N "github.com/sagernet/sing/common/network"
 
 	"github.com/gofrs/uuid/v5"
@@ -115,13 +115,13 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata Metadata, router ad
 	download := new(atomic.Int64)
 
 	t := &tcpTracker{
-		ExtendedConn: trackerconn.NewHook(conn, func(n int64) {
+		ExtendedConn: bufio.NewCounterConn(conn, []N.CountFunc{func(n int64) {
 			upload.Add(n)
 			manager.PushUploaded(n)
-		}, func(n int64) {
+		}}, []N.CountFunc{func(n int64) {
 			download.Add(n)
 			manager.PushDownloaded(n)
-		}),
+		}}),
 		manager: manager,
 		trackerInfo: &trackerInfo{
 			UUID:          uuid,
@@ -202,13 +202,13 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata Metadata, route
 	download := new(atomic.Int64)
 
 	ut := &udpTracker{
-		PacketConn: trackerconn.NewHookPacket(conn, func(n int64) {
+		PacketConn: bufio.NewCounterPacketConn(conn, []N.CountFunc{func(n int64) {
 			upload.Add(n)
 			manager.PushUploaded(n)
-		}, func(n int64) {
+		}}, []N.CountFunc{func(n int64) {
 			download.Add(n)
 			manager.PushDownloaded(n)
-		}),
+		}}),
 		manager: manager,
 		trackerInfo: &trackerInfo{
 			UUID:          uuid,

+ 0 - 108
experimental/trackerconn/conn.go

@@ -1,108 +0,0 @@
-package trackerconn
-
-import (
-	"net"
-
-	"github.com/sagernet/sing/common/atomic"
-	"github.com/sagernet/sing/common/buf"
-	"github.com/sagernet/sing/common/bufio"
-	N "github.com/sagernet/sing/common/network"
-)
-
-func New(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *Conn {
-	return &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
-}
-
-func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64)) *HookConn {
-	return &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
-}
-
-type Conn struct {
-	N.ExtendedConn
-	readCounter  []*atomic.Int64
-	writeCounter []*atomic.Int64
-}
-
-func (c *Conn) Read(p []byte) (n int, err error) {
-	n, err = c.ExtendedConn.Read(p)
-	for _, counter := range c.readCounter {
-		counter.Add(int64(n))
-	}
-	return n, err
-}
-
-func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
-	err := c.ExtendedConn.ReadBuffer(buffer)
-	if err != nil {
-		return err
-	}
-	for _, counter := range c.readCounter {
-		counter.Add(int64(buffer.Len()))
-	}
-	return nil
-}
-
-func (c *Conn) Write(p []byte) (n int, err error) {
-	n, err = c.ExtendedConn.Write(p)
-	for _, counter := range c.writeCounter {
-		counter.Add(int64(n))
-	}
-	return n, err
-}
-
-func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
-	dataLen := int64(buffer.Len())
-	err := c.ExtendedConn.WriteBuffer(buffer)
-	if err != nil {
-		return err
-	}
-	for _, counter := range c.writeCounter {
-		counter.Add(dataLen)
-	}
-	return nil
-}
-
-func (c *Conn) Upstream() any {
-	return c.ExtendedConn
-}
-
-type HookConn struct {
-	N.ExtendedConn
-	readCounter  func(n int64)
-	writeCounter func(n int64)
-}
-
-func (c *HookConn) Read(p []byte) (n int, err error) {
-	n, err = c.ExtendedConn.Read(p)
-	c.readCounter(int64(n))
-	return n, err
-}
-
-func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error {
-	err := c.ExtendedConn.ReadBuffer(buffer)
-	if err != nil {
-		return err
-	}
-	c.readCounter(int64(buffer.Len()))
-	return nil
-}
-
-func (c *HookConn) Write(p []byte) (n int, err error) {
-	n, err = c.ExtendedConn.Write(p)
-	c.writeCounter(int64(n))
-	return n, err
-}
-
-func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error {
-	dataLen := int64(buffer.Len())
-	err := c.ExtendedConn.WriteBuffer(buffer)
-	if err != nil {
-		return err
-	}
-	c.writeCounter(dataLen)
-	return nil
-}
-
-func (c *HookConn) Upstream() any {
-	return c.ExtendedConn
-}

+ 0 - 76
experimental/trackerconn/packet_conn.go

@@ -1,76 +0,0 @@
-package trackerconn
-
-import (
-	"github.com/sagernet/sing/common/atomic"
-	"github.com/sagernet/sing/common/buf"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
-)
-
-func NewPacket(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *PacketConn {
-	return &PacketConn{conn, readCounter, writeCounter}
-}
-
-func NewHookPacket(conn N.PacketConn, readCounter func(n int64), writeCounter func(n int64)) *HookPacketConn {
-	return &HookPacketConn{conn, readCounter, writeCounter}
-}
-
-type PacketConn struct {
-	N.PacketConn
-	readCounter  []*atomic.Int64
-	writeCounter []*atomic.Int64
-}
-
-func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
-	destination, err = c.PacketConn.ReadPacket(buffer)
-	if err == nil {
-		for _, counter := range c.readCounter {
-			counter.Add(int64(buffer.Len()))
-		}
-	}
-	return
-}
-
-func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
-	dataLen := int64(buffer.Len())
-	err := c.PacketConn.WritePacket(buffer, destination)
-	if err != nil {
-		return err
-	}
-	for _, counter := range c.writeCounter {
-		counter.Add(dataLen)
-	}
-	return nil
-}
-
-func (c *PacketConn) Upstream() any {
-	return c.PacketConn
-}
-
-type HookPacketConn struct {
-	N.PacketConn
-	readCounter  func(n int64)
-	writeCounter func(n int64)
-}
-
-func (c *HookPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
-	destination, err = c.PacketConn.ReadPacket(buffer)
-	if err == nil {
-		c.readCounter(int64(buffer.Len()))
-	}
-	return
-}
-
-func (c *HookPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
-	dataLen := int64(buffer.Len())
-	err := c.PacketConn.WritePacket(buffer, destination)
-	if err != nil {
-		return err
-	}
-	c.writeCounter(dataLen)
-	return nil
-}
-
-func (c *HookPacketConn) Upstream() any {
-	return c.PacketConn
-}

+ 3 - 3
experimental/v2rayapi/stats.go

@@ -10,9 +10,9 @@ import (
 	"time"
 
 	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/experimental/trackerconn"
 	"github.com/sagernet/sing-box/option"
 	"github.com/sagernet/sing/common/atomic"
+	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
 	N "github.com/sagernet/sing/common/network"
 )
@@ -83,7 +83,7 @@ func (s *StatsService) RoutedConnection(inbound string, outbound string, user st
 		writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink"))
 	}
 	s.access.Unlock()
-	return trackerconn.New(conn, readCounter, writeCounter)
+	return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
 }
 
 func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, user string, conn N.PacketConn) N.PacketConn {
@@ -109,7 +109,7 @@ func (s *StatsService) RoutedPacketConnection(inbound string, outbound string, u
 		writeCounter = append(writeCounter, s.loadOrCreateCounter("user>>>"+user+">>>traffic>>>downlink"))
 	}
 	s.access.Unlock()
-	return trackerconn.NewPacket(conn, readCounter, writeCounter)
+	return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter)
 }
 
 func (s *StatsService) GetStats(ctx context.Context, request *GetStatsRequest) (*GetStatsResponse, error) {

+ 1 - 1
go.mod

@@ -24,7 +24,7 @@ require (
 	github.com/sagernet/gomobile v0.0.0-20230413023804-244d7ff07035
 	github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32
 	github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
-	github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207
+	github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754
 	github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc
 	github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646
 	github.com/sagernet/sing-shadowsocks v0.2.2-0.20230417102954-f77257340507

+ 2 - 2
go.sum

@@ -111,8 +111,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL
 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
 github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
 github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
-github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss=
-github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
+github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754 h1:y89Ntm1rrZPQVb1f+TKd4DH6NwX5XCyMIwoseTQd/5U=
+github.com/sagernet/sing v0.2.5-0.20230425122720-bf0aaacc6754/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
 github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc h1:hmbuqKv48SAjiKPoqtJGvS5pEHVPZjTHq9CPwQY2cZ4=
 github.com/sagernet/sing-dns v0.1.5-0.20230415085626-111ecf799dfc/go.mod h1:ZKuuqgsHRxDahYrzgSgy4vIAGGuKPlIf4hLcNzYzLkY=
 github.com/sagernet/sing-mux v0.0.0-20230517134606-1ebe6bb26646 h1:X3ADfMqeGns1Q1FlXc9kaL9FwW1UM6D6tEQo8jFstpc=

+ 101 - 0
outbound/dns.go

@@ -11,6 +11,7 @@ import (
 	"github.com/sagernet/sing-dns"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
+	"github.com/sagernet/sing/common/bufio"
 	"github.com/sagernet/sing/common/canceler"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
@@ -101,6 +102,24 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
 }
 
 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	var reader N.PacketReader = conn
+	var counters []N.CountFunc
+	var cachedBuffer []*N.PacketBuffer
+	for {
+		reader, counters = N.UnwrapCountPacketReader(reader, counters)
+		if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
+			packet := cachedReader.ReadCachedPacket()
+			if packet != nil {
+				cachedBuffer = append([]*N.PacketBuffer{packet}, cachedBuffer...)
+				continue
+			}
+		}
+		if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
+			return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedBuffer, metadata)
+		}
+		break
+	}
+
 	ctx = adapter.WithContext(ctx, &metadata)
 	fastClose, cancel := common.ContextWithCancelCause(ctx)
 	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
@@ -153,3 +172,85 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
 	})
 	return group.Run(fastClose)
 }
+
+func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
+	ctx = adapter.WithContext(ctx, &metadata)
+	fastClose, cancel := common.ContextWithCancelCause(ctx)
+	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
+	var group task.Group
+	group.Append0(func(ctx context.Context) error {
+		var buffer *buf.Buffer
+		newBuffer := func() *buf.Buffer {
+			if buffer != nil {
+				buffer.Release()
+			}
+			buffer = buf.NewSize(dns.FixedPacketSize)
+			buffer.FullReset()
+			return buffer
+		}
+		for {
+			var message mDNS.Msg
+			var destination M.Socksaddr
+			var err error
+			if len(cached) > 0 {
+				packet := cached[0]
+				cached = cached[1:]
+				for _, counter := range readCounters {
+					counter(int64(packet.Buffer.Len()))
+				}
+				err = message.Unpack(packet.Buffer.Bytes())
+				packet.Buffer.Release()
+				if err != nil {
+					cancel(err)
+					return err
+				}
+				destination = packet.Destination
+			} else {
+				destination, err = readWaiter.WaitReadPacket(newBuffer)
+				if err != nil {
+					if buffer != nil {
+						buffer.Release()
+					}
+					cancel(err)
+					return err
+				}
+				for _, counter := range readCounters {
+					counter(int64(buffer.Len()))
+				}
+				err = message.Unpack(buffer.Bytes())
+				buffer.Release()
+				if err != nil {
+					cancel(err)
+					return err
+				}
+				timeout.Update()
+			}
+			metadataInQuery := metadata
+			go func() error {
+				response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
+				if err != nil {
+					cancel(err)
+					return err
+				}
+				timeout.Update()
+				responseBuffer := buf.NewPacket()
+				n, err := response.PackBuffer(responseBuffer.FreeBytes())
+				if err != nil {
+					cancel(err)
+					responseBuffer.Release()
+					return err
+				}
+				responseBuffer.Truncate(len(n))
+				err = conn.WritePacket(responseBuffer, destination)
+				if err != nil {
+					cancel(err)
+				}
+				return err
+			}()
+		}
+	})
+	group.Cleanup(func() {
+		conn.Close()
+	})
+	return group.Run(fastClose)
+}