瀏覽代碼

Fix bad type cast in dispatcher

世界 2 年之前
父節點
當前提交
0a099d972b
共有 7 個文件被更改,包括 49 次插入13 次删除
  1. 8 3
      app/dispatcher/default.go
  2. 2 2
      common/buf/reader.go
  3. 4 4
      common/singbridge/handler.go
  4. 30 0
      common/singbridge/reader.go
  5. 1 1
      go.mod
  6. 2 2
      go.sum
  7. 2 1
      transport/internet/grpc/config_test.go

+ 8 - 3
app/dispatcher/default.go

@@ -30,10 +30,15 @@ var errSniffingTimeout = newError("timeout on sniffing")
 
 type cachedReader struct {
 	sync.Mutex
-	reader *pipe.Reader
+	reader timeoutReader
 	cache  buf.MultiBuffer
 }
 
+type timeoutReader interface {
+	buf.Reader
+	buf.TimeoutReader
+}
+
 func (r *cachedReader) Cache(b *buf.Buffer) {
 	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
 	r.Lock()
@@ -84,7 +89,7 @@ func (r *cachedReader) Interrupt() {
 		r.cache = buf.ReleaseMulti(r.cache)
 	}
 	r.Unlock()
-	r.reader.Interrupt()
+	common.Interrupt(r.reader)
 }
 
 // DefaultDispatcher is a default implementation of Dispatcher.
@@ -345,7 +350,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 		d.routedDispatch(ctx, outbound, destination)
 	} else {
 		cReader := &cachedReader{
-			reader: outbound.Reader.(*pipe.Reader),
+			reader: outbound.Reader.(timeoutReader),
 		}
 		outbound.Reader = cReader
 		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)

+ 2 - 2
common/buf/reader.go

@@ -7,7 +7,7 @@ import (
 	"github.com/xtls/xray-core/common/errors"
 )
 
-func readOneUDP(r io.Reader) (*Buffer, error) {
+func ReadOneUDP(r io.Reader) (*Buffer, error) {
 	b := New()
 	for i := 0; i < 64; i++ {
 		_, err := b.ReadFrom(r)
@@ -166,7 +166,7 @@ type PacketReader struct {
 
 // ReadMultiBuffer implements Reader.
 func (r *PacketReader) ReadMultiBuffer() (MultiBuffer, error) {
-	b, err := readOneUDP(r.Reader)
+	b, err := ReadOneUDP(r.Reader)
 	if err != nil {
 		return nil, err
 	}

+ 4 - 4
common/singbridge/handler.go

@@ -2,11 +2,10 @@ package singbridge
 
 import (
 	"context"
-	"io"
 
+	"github.com/sagernet/sing/common/bufio"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
-	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/session"
@@ -40,9 +39,10 @@ func (d *Dispatcher) NewConnection(ctx context.Context, conn net.Conn, metadata
 }
 
 func (d *Dispatcher) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
+	packetConn := &PacketConn{bufio.NewBindPacketConn(conn.(net.PacketConn), metadata.Destination)}
 	return d.upstream.DispatchLink(ctx, ToDestination(metadata.Destination, net.Network_UDP), &transport.Link{
-		Reader: buf.NewPacketReader(conn.(io.Reader)),
-		Writer: buf.NewWriter(conn.(io.Writer)),
+		Reader: packetConn,
+		Writer: packetConn,
 	})
 }
 

+ 30 - 0
common/singbridge/reader.go

@@ -14,6 +14,9 @@ var (
 	_ buf.Reader        = (*Conn)(nil)
 	_ buf.TimeoutReader = (*Conn)(nil)
 	_ buf.Writer        = (*Conn)(nil)
+	_ buf.Reader        = (*PacketConn)(nil)
+	_ buf.TimeoutReader = (*PacketConn)(nil)
+	_ buf.Writer        = (*PacketConn)(nil)
 )
 
 type Conn struct {
@@ -64,3 +67,30 @@ func (c *Conn) WriteMultiBuffer(bufferList buf.MultiBuffer) error {
 	}
 	return nil
 }
+
+type PacketConn struct {
+	net.Conn
+}
+
+func (c *PacketConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	buffer, err := buf.ReadOneUDP(c.Conn)
+	if err != nil {
+		return nil, err
+	}
+	return buf.MultiBuffer{buffer}, nil
+}
+
+func (c *PacketConn) ReadMultiBufferTimeout(duration time.Duration) (buf.MultiBuffer, error) {
+	err := c.SetReadDeadline(time.Now().Add(duration))
+	if err != nil {
+		return nil, err
+	}
+	defer c.SetReadDeadline(time.Time{})
+	return c.ReadMultiBuffer()
+}
+
+func (c *PacketConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	mb, err := buf.WriteMultiBuffer(c.Conn, mb)
+	buf.ReleaseMulti(mb)
+	return err
+}

+ 1 - 1
go.mod

@@ -14,7 +14,7 @@ require (
 	github.com/quic-go/quic-go v0.34.0
 	github.com/refraction-networking/utls v1.3.2
 	github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207
-	github.com/sagernet/sing-mux v0.0.0-20230424061035-f6a6b7258c29
+	github.com/sagernet/sing-mux v0.0.0-20230425054943-ec2a972d0809
 	github.com/sagernet/sing-shadowsocks v0.2.1
 	github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c
 	github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb

+ 2 - 2
go.sum

@@ -148,8 +148,8 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR
 github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
 github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207 h1:+dDVjW20IT+e8maKryaDeRY2+RFmTFdrQeIzqE2WOss=
 github.com/sagernet/sing v0.2.5-0.20230423085534-0902e6216207/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
-github.com/sagernet/sing-mux v0.0.0-20230424061035-f6a6b7258c29 h1:JqNg7xbPHV7YtQFiaqYTRY79fjonNN8u4GMf4T6Lb3E=
-github.com/sagernet/sing-mux v0.0.0-20230424061035-f6a6b7258c29/go.mod h1:pF+RnLvCAOhECrvauy6LYOpBakJ/vuaF1Wm4lPsWryI=
+github.com/sagernet/sing-mux v0.0.0-20230425054943-ec2a972d0809 h1:OJsley0JzpFCkwrl4BU38YX+hhVUrcCasomJsv6g6CY=
+github.com/sagernet/sing-mux v0.0.0-20230425054943-ec2a972d0809/go.mod h1:pF+RnLvCAOhECrvauy6LYOpBakJ/vuaF1Wm4lPsWryI=
 github.com/sagernet/sing-shadowsocks v0.2.1 h1:FvdLQOqpvxHBJUcUe4fvgiYP2XLLwH5i1DtXQviVEPw=
 github.com/sagernet/sing-shadowsocks v0.2.1/go.mod h1:T/OgurSjsAe+Ug3+6PprXjmgHFmJidjOvQcjXGTKb3I=
 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=

+ 2 - 1
transport/internet/grpc/config_test.go

@@ -1,8 +1,9 @@
 package grpc
 
 import (
-	"github.com/stretchr/testify/assert"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func TestConfig_GetServiceName(t *testing.T) {