|
@@ -5,6 +5,7 @@ import (
|
|
|
"io"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
+ "os"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
"time"
|
|
@@ -13,6 +14,7 @@ import (
|
|
|
"github.com/sagernet/sing-box/common/dialer"
|
|
|
C "github.com/sagernet/sing-box/constant"
|
|
|
"github.com/sagernet/sing/common"
|
|
|
+ "github.com/sagernet/sing/common/buf"
|
|
|
"github.com/sagernet/sing/common/bufio"
|
|
|
"github.com/sagernet/sing/common/canceler"
|
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
@@ -190,14 +192,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
|
|
|
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
|
|
|
}
|
|
|
|
|
|
-func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
- originSource := source
|
|
|
- originDestination := destination
|
|
|
+func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
+ var (
|
|
|
+ sourceReader io.Reader = source
|
|
|
+ destinationWriter io.Writer = destination
|
|
|
+ )
|
|
|
var readCounters, writeCounters []N.CountFunc
|
|
|
for {
|
|
|
- source, readCounters = N.UnwrapCountReader(source, readCounters)
|
|
|
- destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
|
|
- if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
|
|
+ sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
|
|
|
+ destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
|
|
|
+ if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
|
|
|
cachedBuffer := cachedSrc.ReadCached()
|
|
|
if cachedBuffer != nil {
|
|
|
dataLen := cachedBuffer.Len()
|
|
@@ -207,7 +211,7 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
|
|
if done.Swap(true) {
|
|
|
onClose(err)
|
|
|
}
|
|
|
- common.Close(originSource, originDestination)
|
|
|
+ common.Close(source, destination)
|
|
|
if !direction {
|
|
|
m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
|
|
} else {
|
|
@@ -226,9 +230,13 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
|
|
}
|
|
|
break
|
|
|
}
|
|
|
- if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() {
|
|
|
- _, err := destination.Write(nil)
|
|
|
+ if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
|
|
|
+ err := m.connectionCopyEarly(source, destination)
|
|
|
if err != nil {
|
|
|
+ if done.Swap(true) {
|
|
|
+ onClose(err)
|
|
|
+ }
|
|
|
+ common.Close(source, destination)
|
|
|
if !direction {
|
|
|
m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
|
|
|
} else {
|
|
@@ -237,20 +245,20 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
- _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
|
|
+ _, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
|
|
|
if err != nil {
|
|
|
- common.Close(originDestination)
|
|
|
+ common.Close(source, destination)
|
|
|
} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
|
|
|
err = duplexDst.CloseWrite()
|
|
|
if err != nil {
|
|
|
- common.Close(originSource, originDestination)
|
|
|
+ common.Close(source, destination)
|
|
|
}
|
|
|
} else {
|
|
|
- common.Close(originDestination)
|
|
|
+ destination.Close()
|
|
|
}
|
|
|
if done.Swap(true) {
|
|
|
onClose(err)
|
|
|
- common.Close(originSource, originDestination)
|
|
|
+ common.Close(source, destination)
|
|
|
}
|
|
|
if !direction {
|
|
|
if err == nil {
|
|
@@ -271,6 +279,28 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
|
|
|
+ payload := buf.NewPacket()
|
|
|
+ defer payload.Release()
|
|
|
+ err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
|
|
|
+ if err != nil {
|
|
|
+ if err == os.ErrInvalid {
|
|
|
+ return common.Error(destination.Write(nil))
|
|
|
+ }
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ _, err = payload.ReadOnceFrom(source)
|
|
|
+ if err != nil && !E.IsTimeout(err) {
|
|
|
+ return E.Cause(err, "read payload")
|
|
|
+ }
|
|
|
+ _ = source.SetReadDeadline(time.Time{})
|
|
|
+ _, err = destination.Write(payload.Bytes())
|
|
|
+ if err != nil {
|
|
|
+ return E.Cause(err, "write payload")
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
_, err := bufio.CopyPacket(destination, source)
|
|
|
if !direction {
|