1
0
世界 2 жил өмнө
parent
commit
5ce3ddee9b

+ 1 - 1
go.mod

@@ -24,7 +24,7 @@ require (
 	github.com/sagernet/gomobile v0.0.0-20221130124640-349ebaa752ca
 	github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32
 	github.com/sagernet/reality v0.0.0-20230226124550-f98d51fa21b5
-	github.com/sagernet/sing v0.1.8-0.20230221060643-3401d210384b
+	github.com/sagernet/sing v0.1.8-0.20230226145949-3f0b21359af6
 	github.com/sagernet/sing-dns v0.1.4
 	github.com/sagernet/sing-shadowsocks v0.1.2-0.20230221080503-769c01d6bba9
 	github.com/sagernet/sing-shadowtls v0.0.0-20230221123345-78e50cd7b587

+ 2 - 2
go.sum

@@ -129,8 +129,8 @@ github.com/sagernet/reality v0.0.0-20230226124550-f98d51fa21b5 h1:yDic66vLGsY3zq
 github.com/sagernet/reality v0.0.0-20230226124550-f98d51fa21b5/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
 github.com/sagernet/sing v0.0.0-20220812082120-05f9836bff8f/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
 github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
-github.com/sagernet/sing v0.1.8-0.20230221060643-3401d210384b h1:Ji2AfGlc4j9AitobOx4k3BCj7eS5nSxL1cgaL81zvlo=
-github.com/sagernet/sing v0.1.8-0.20230221060643-3401d210384b/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
+github.com/sagernet/sing v0.1.8-0.20230226145949-3f0b21359af6 h1:QLfccQ8S1nqw5+xYEM/xLXQDq70BjAeyuVWluIEytww=
+github.com/sagernet/sing v0.1.8-0.20230226145949-3f0b21359af6/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
 github.com/sagernet/sing-dns v0.1.4 h1:7VxgeoSCiiazDSaXXQVcvrTBxFpOePPq/4XdgnUDN+0=
 github.com/sagernet/sing-dns v0.1.4/go.mod h1:1+6pCa48B1AI78lD+/i/dLgpw4MwfnsSpZo0Ds8wzzk=
 github.com/sagernet/sing-shadowsocks v0.1.2-0.20230221080503-769c01d6bba9 h1:qS39eA4C7x+zhEkySbASrtmb6ebdy5v0y2M6mgkmSO0=

+ 21 - 43
outbound/default.go

@@ -39,30 +39,6 @@ func (a *myOutboundAdapter) Network() []string {
 }
 
 func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
-	ctx = adapter.WithContext(ctx, &metadata)
-	var outConn net.Conn
-	var err error
-	if len(metadata.DestinationAddresses) > 0 {
-		outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
-	} else {
-		outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
-	}
-	if err != nil {
-		return N.HandshakeFailure(conn, err)
-	}
-	if cachedReader, isCached := conn.(N.CachedReader); isCached {
-		payload := cachedReader.ReadCached()
-		if payload != nil && !payload.IsEmpty() {
-			_, err = outConn.Write(payload.Bytes())
-			if err != nil {
-				return err
-			}
-		}
-	}
-	return bufio.CopyConn(ctx, conn, outConn)
-}
-
-func NewEarlyConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
 	ctx = adapter.WithContext(ctx, &metadata)
 	var outConn net.Conn
 	var err error
@@ -111,28 +87,30 @@ func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) erro
 			return bufio.CopyConn(ctx, conn, serverConn)
 		}
 	}
-	_payload := buf.StackNew()
-	payload := common.Dup(_payload)
-	err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
-	if err != os.ErrInvalid {
-		if err != nil {
-			return err
-		}
-		_, err = payload.ReadOnceFrom(conn)
-		if err != nil && !E.IsTimeout(err) {
-			return E.Cause(err, "read payload")
+	if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](conn); isEarlyConn && earlyConn.NeedHandshake() {
+		_payload := buf.StackNew()
+		payload := common.Dup(_payload)
+		err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
+		if err != os.ErrInvalid {
+			if err != nil {
+				return err
+			}
+			_, err = payload.ReadOnceFrom(conn)
+			if err != nil && !E.IsTimeout(err) {
+				return E.Cause(err, "read payload")
+			}
+			err = conn.SetReadDeadline(time.Time{})
+			if err != nil {
+				payload.Release()
+				return err
+			}
 		}
-		err = conn.SetReadDeadline(time.Time{})
+		_, err = serverConn.Write(payload.Bytes())
 		if err != nil {
-			payload.Release()
-			return err
+			return N.HandshakeFailure(conn, err)
 		}
+		runtime.KeepAlive(_payload)
+		payload.Release()
 	}
-	_, err = serverConn.Write(payload.Bytes())
-	if err != nil {
-		return N.HandshakeFailure(conn, err)
-	}
-	runtime.KeepAlive(_payload)
-	payload.Release()
 	return bufio.CopyConn(ctx, conn, serverConn)
 }

+ 1 - 1
outbound/shadowsocks.go

@@ -125,7 +125,7 @@ func (h *Shadowsocks) ListenPacket(ctx context.Context, destination M.Socksaddr)
 }
 
 func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return NewEarlyConnection(ctx, h, conn, metadata)
+	return NewConnection(ctx, h, conn, metadata)
 }
 
 func (h *Shadowsocks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {

+ 1 - 1
outbound/trojan.go

@@ -96,7 +96,7 @@ func (h *Trojan) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
 }
 
 func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return NewEarlyConnection(ctx, h, conn, metadata)
+	return NewConnection(ctx, h, conn, metadata)
 }
 
 func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {

+ 1 - 1
outbound/vless.go

@@ -135,7 +135,7 @@ func (h *VLESS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.
 }
 
 func (h *VLESS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return NewEarlyConnection(ctx, h, conn, metadata)
+	return NewConnection(ctx, h, conn, metadata)
 }
 
 func (h *VLESS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {

+ 1 - 1
outbound/vmess.go

@@ -133,7 +133,7 @@ func (h *VMess) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.
 }
 
 func (h *VMess) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return NewEarlyConnection(ctx, h, conn, metadata)
+	return NewConnection(ctx, h, conn, metadata)
 }
 
 func (h *VMess) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {

+ 10 - 0
transport/trojan/protocol.go

@@ -26,6 +26,8 @@ const (
 
 var CRLF = []byte{'\r', '\n'}
 
+var _ N.EarlyConn = (*ClientConn)(nil)
+
 type ClientConn struct {
 	N.ExtendedConn
 	key           [KeyLength]byte
@@ -41,6 +43,10 @@ func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr)
 	}
 }
 
+func (c *ClientConn) NeedHandshake() bool {
+	return !c.headerWritten
+}
+
 func (c *ClientConn) Write(p []byte) (n int, err error) {
 	if c.headerWritten {
 		return c.ExtendedConn.Write(p)
@@ -101,6 +107,10 @@ func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
 	}
 }
 
+func (c *ClientPacketConn) NeedHandshake() bool {
+	return !c.headerWritten
+}
+
 func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
 	return ReadPacket(c.Conn, buffer)
 }

+ 7 - 0
transport/vless/client.go

@@ -10,6 +10,7 @@ import (
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
 
 	"github.com/gofrs/uuid"
 )
@@ -82,6 +83,8 @@ func (c *Client) DialEarlyXUDPPacketConn(conn net.Conn, destination M.Socksaddr)
 	return vmess.NewXUDPConn(&Conn{Conn: conn, protocolConn: conn, key: c.key, command: vmess.CommandMux, destination: destination, flow: c.flow}, destination), nil
 }
 
+var _ N.EarlyConn = (*Conn)(nil)
+
 type Conn struct {
 	net.Conn
 	protocolConn   net.Conn
@@ -93,6 +96,10 @@ type Conn struct {
 	responseRead   bool
 }
 
+func (c *Conn) NeedHandshake() bool {
+	return !c.requestWritten
+}
+
 func (c *Conn) Read(b []byte) (n int, err error) {
 	if !c.responseRead {
 		err = ReadResponse(c.Conn)