Browse Source

Fix copy early conn

世界 7 months ago
parent
commit
4f3ee61104
1 changed files with 44 additions and 14 deletions
  1. 44 14
      route/conn.go

+ 44 - 14
route/conn.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net"
 	"net/netip"
+	"os"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -13,6 +14,7 @@ import (
 	"github.com/sagernet/sing-box/common/dialer"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	"github.com/sagernet/sing/common/canceler"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -190,14 +192,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
 	go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
 }
 
-func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
-	originSource := source
-	originDestination := destination
+func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
+	var (
+		sourceReader      io.Reader = source
+		destinationWriter io.Writer = destination
+	)
 	var readCounters, writeCounters []N.CountFunc
 	for {
-		source, readCounters = N.UnwrapCountReader(source, readCounters)
-		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
-		if cachedSrc, isCached := source.(N.CachedReader); isCached {
+		sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
+		destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
+		if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
 			cachedBuffer := cachedSrc.ReadCached()
 			if cachedBuffer != nil {
 				dataLen := cachedBuffer.Len()
@@ -207,7 +211,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
 					if done.Swap(true) {
 						onClose(err)
 					}
-					common.Close(originSource, originDestination)
+					common.Close(source, destination)
 					if !direction {
 						m.logger.ErrorContext(ctx, "connection upload payload: ", err)
 					} else {
@@ -226,9 +230,13 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
 		}
 		break
 	}
-	if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() {
-		_, err := destination.Write(nil)
+	if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
+		err := m.connectionCopyEarly(source, destination)
 		if err != nil {
+			if done.Swap(true) {
+				onClose(err)
+			}
+			common.Close(source, destination)
 			if !direction {
 				m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
 			} else {
@@ -237,20 +245,20 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
 			return
 		}
 	}
-	_, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
+	_, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
 	if err != nil {
-		common.Close(originDestination)
+		common.Close(source, destination)
 	} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
 		err = duplexDst.CloseWrite()
 		if err != nil {
-			common.Close(originSource, originDestination)
+			common.Close(source, destination)
 		}
 	} else {
-		common.Close(originDestination)
+		destination.Close()
 	}
 	if done.Swap(true) {
 		onClose(err)
-		common.Close(originSource, originDestination)
+		common.Close(source, destination)
 	}
 	if !direction {
 		if err == nil {
@@ -271,6 +279,28 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
 	}
 }
 
+func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
+	payload := buf.NewPacket()
+	defer payload.Release()
+	err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
+	if err != nil {
+		if err == os.ErrInvalid {
+			return common.Error(destination.Write(nil))
+		}
+		return err
+	}
+	_, err = payload.ReadOnceFrom(source)
+	if err != nil && !E.IsTimeout(err) {
+		return E.Cause(err, "read payload")
+	}
+	_ = source.SetReadDeadline(time.Time{})
+	_, err = destination.Write(payload.Bytes())
+	if err != nil {
+		return E.Cause(err, "write payload")
+	}
+	return nil
+}
+
 func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
 	_, err := bufio.CopyPacket(destination, source)
 	if !direction {