|
@@ -5,6 +5,7 @@ import (
|
|
|
"io"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
+ "sync"
|
|
|
"sync/atomic"
|
|
|
"time"
|
|
|
|
|
@@ -18,31 +19,35 @@ import (
|
|
|
"github.com/sagernet/sing/common/logger"
|
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
|
N "github.com/sagernet/sing/common/network"
|
|
|
+ "github.com/sagernet/sing/common/x/list"
|
|
|
)
|
|
|
|
|
|
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
|
|
|
|
|
|
type ConnectionManager struct {
|
|
|
- logger logger.ContextLogger
|
|
|
- monitor *ConnectionMonitor
|
|
|
+ logger logger.ContextLogger
|
|
|
+ access sync.Mutex
|
|
|
+ connections list.List[io.Closer]
|
|
|
}
|
|
|
|
|
|
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
|
|
|
return &ConnectionManager{
|
|
|
- logger: logger,
|
|
|
- monitor: NewConnectionMonitor(),
|
|
|
+ logger: logger,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (m *ConnectionManager) Start(stage adapter.StartStage) error {
|
|
|
- if stage != adapter.StartStateInitialize {
|
|
|
- return nil
|
|
|
- }
|
|
|
- return m.monitor.Start()
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (m *ConnectionManager) Close() error {
|
|
|
- return m.monitor.Close()
|
|
|
+ m.access.Lock()
|
|
|
+ defer m.access.Unlock()
|
|
|
+ for element := m.connections.Front(); element != nil; element = element.Next() {
|
|
|
+ common.Close(element.Value)
|
|
|
+ }
|
|
|
+ m.connections.Init()
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
|
@@ -57,95 +62,32 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co
|
|
|
remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
|
|
|
}
|
|
|
if err != nil {
|
|
|
+ err = E.Cause(err, "open outbound connection")
|
|
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
|
|
- m.logger.ErrorContext(ctx, "open outbound connection: ", err)
|
|
|
+ m.logger.ErrorContext(ctx, err)
|
|
|
return
|
|
|
}
|
|
|
err = N.ReportConnHandshakeSuccess(conn, remoteConn)
|
|
|
if err != nil {
|
|
|
+ err = E.Cause(err, "report handshake success")
|
|
|
remoteConn.Close()
|
|
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
|
|
- m.logger.ErrorContext(ctx, "report handshake success: ", err)
|
|
|
+ m.logger.ErrorContext(ctx, err)
|
|
|
return
|
|
|
}
|
|
|
+ m.access.Lock()
|
|
|
+ element := m.connections.PushBack(conn)
|
|
|
+ m.access.Unlock()
|
|
|
+ onClose = N.AppendClose(onClose, func(it error) {
|
|
|
+ m.access.Lock()
|
|
|
+ defer m.access.Unlock()
|
|
|
+ m.connections.Remove(element)
|
|
|
+ })
|
|
|
var done atomic.Bool
|
|
|
- if ctx.Done() != nil {
|
|
|
- onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
|
|
|
- }
|
|
|
go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
|
|
|
go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
|
|
|
}
|
|
|
|
|
|
-func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
- originSource := source
|
|
|
- originDestination := destination
|
|
|
- var readCounters, writeCounters []N.CountFunc
|
|
|
- for {
|
|
|
- source, readCounters = N.UnwrapCountReader(source, readCounters)
|
|
|
- destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
|
|
- if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
|
|
- cachedBuffer := cachedSrc.ReadCached()
|
|
|
- if cachedBuffer != nil {
|
|
|
- dataLen := cachedBuffer.Len()
|
|
|
- _, err := destination.Write(cachedBuffer.Bytes())
|
|
|
- cachedBuffer.Release()
|
|
|
- if err != nil {
|
|
|
- m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
|
|
- if done.Swap(true) {
|
|
|
- if onClose != nil {
|
|
|
- onClose(err)
|
|
|
- }
|
|
|
- }
|
|
|
- common.Close(originSource, originDestination)
|
|
|
- return
|
|
|
- }
|
|
|
- for _, counter := range readCounters {
|
|
|
- counter(int64(dataLen))
|
|
|
- }
|
|
|
- for _, counter := range writeCounters {
|
|
|
- counter(int64(dataLen))
|
|
|
- }
|
|
|
- }
|
|
|
- continue
|
|
|
- }
|
|
|
- break
|
|
|
- }
|
|
|
- _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
|
|
- if err != nil {
|
|
|
- common.Close(originSource, originDestination)
|
|
|
- } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
|
|
|
- err = duplexDst.CloseWrite()
|
|
|
- if err != nil {
|
|
|
- common.Close(originSource, originDestination)
|
|
|
- }
|
|
|
- } else {
|
|
|
- common.Close(originDestination)
|
|
|
- }
|
|
|
- if done.Swap(true) {
|
|
|
- if onClose != nil {
|
|
|
- onClose(err)
|
|
|
- }
|
|
|
- common.Close(originSource, originDestination)
|
|
|
- }
|
|
|
- if !direction {
|
|
|
- if err == nil {
|
|
|
- m.logger.DebugContext(ctx, "connection upload finished")
|
|
|
- } else if !E.IsClosedOrCanceled(err) {
|
|
|
- m.logger.ErrorContext(ctx, "connection upload closed: ", err)
|
|
|
- } else {
|
|
|
- m.logger.TraceContext(ctx, "connection upload closed")
|
|
|
- }
|
|
|
- } else {
|
|
|
- if err == nil {
|
|
|
- m.logger.DebugContext(ctx, "connection download finished")
|
|
|
- } else if !E.IsClosedOrCanceled(err) {
|
|
|
- m.logger.ErrorContext(ctx, "connection download closed: ", err)
|
|
|
- } else {
|
|
|
- m.logger.TraceContext(ctx, "connection download closed")
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
|
|
ctx = adapter.WithContext(ctx, &metadata)
|
|
|
var (
|
|
@@ -227,58 +169,91 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
|
|
|
ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
|
|
|
}
|
|
|
destination := bufio.NewPacketConn(remotePacketConn)
|
|
|
+ m.access.Lock()
|
|
|
+ element := m.connections.PushBack(conn)
|
|
|
+ m.access.Unlock()
|
|
|
+ onClose = N.AppendClose(onClose, func(it error) {
|
|
|
+ m.access.Lock()
|
|
|
+ defer m.access.Unlock()
|
|
|
+ m.connections.Remove(element)
|
|
|
+ })
|
|
|
var done atomic.Bool
|
|
|
- if ctx.Done() != nil {
|
|
|
- onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn))
|
|
|
- }
|
|
|
go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
|
|
|
go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
|
|
|
}
|
|
|
|
|
|
-func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
- _, err := bufio.CopyPacket(destination, source)
|
|
|
- /*var readCounters, writeCounters []N.CountFunc
|
|
|
- var cachedPackets []*N.PacketBuffer
|
|
|
+func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
originSource := source
|
|
|
+ originDestination := destination
|
|
|
+ var readCounters, writeCounters []N.CountFunc
|
|
|
for {
|
|
|
- source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
|
|
|
- destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters)
|
|
|
- if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
|
|
|
- packet := cachedReader.ReadCachedPacket()
|
|
|
- if packet != nil {
|
|
|
- cachedPackets = append(cachedPackets, packet)
|
|
|
- continue
|
|
|
+ source, readCounters = N.UnwrapCountReader(source, readCounters)
|
|
|
+ destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
|
|
|
+ if cachedSrc, isCached := source.(N.CachedReader); isCached {
|
|
|
+ cachedBuffer := cachedSrc.ReadCached()
|
|
|
+ if cachedBuffer != nil {
|
|
|
+ dataLen := cachedBuffer.Len()
|
|
|
+ _, err := destination.Write(cachedBuffer.Bytes())
|
|
|
+ cachedBuffer.Release()
|
|
|
+ if err != nil {
|
|
|
+ if done.Swap(true) {
|
|
|
+ onClose(err)
|
|
|
+ }
|
|
|
+ common.Close(originSource, originDestination)
|
|
|
+ if !direction {
|
|
|
+ m.logger.ErrorContext(ctx, "connection upload payload: ", err)
|
|
|
+ } else {
|
|
|
+ m.logger.ErrorContext(ctx, "connection download payload: ", err)
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+ for _, counter := range readCounters {
|
|
|
+ counter(int64(dataLen))
|
|
|
+ }
|
|
|
+ for _, counter := range writeCounters {
|
|
|
+ counter(int64(dataLen))
|
|
|
+ }
|
|
|
}
|
|
|
+ continue
|
|
|
}
|
|
|
break
|
|
|
}
|
|
|
- var handled bool
|
|
|
- if natConn, isNatConn := source.(udpnat.Conn); isNatConn {
|
|
|
- natConn.SetHandler(&udpHijacker{
|
|
|
- ctx: ctx,
|
|
|
- logger: m.logger,
|
|
|
- source: natConn,
|
|
|
- destination: destination,
|
|
|
- direction: direction,
|
|
|
- readCounters: readCounters,
|
|
|
- writeCounters: writeCounters,
|
|
|
- done: done,
|
|
|
- onClose: onClose,
|
|
|
- })
|
|
|
- handled = true
|
|
|
- }
|
|
|
- if cachedPackets != nil {
|
|
|
- _, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters)
|
|
|
+ _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
|
|
|
+ if err != nil {
|
|
|
+ common.Close(originDestination)
|
|
|
+ } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
|
|
|
+ err = duplexDst.CloseWrite()
|
|
|
if err != nil {
|
|
|
- common.Close(source, destination)
|
|
|
- m.logger.ErrorContext(ctx, "packet upload payload: ", err)
|
|
|
- return
|
|
|
+ common.Close(originSource, originDestination)
|
|
|
}
|
|
|
+ } else {
|
|
|
+ common.Close(originDestination)
|
|
|
}
|
|
|
- if handled {
|
|
|
- return
|
|
|
+ if done.Swap(true) {
|
|
|
+ onClose(err)
|
|
|
+ common.Close(originSource, originDestination)
|
|
|
+ }
|
|
|
+ if !direction {
|
|
|
+ if err == nil {
|
|
|
+ m.logger.DebugContext(ctx, "connection upload finished")
|
|
|
+ } else if !E.IsClosedOrCanceled(err) {
|
|
|
+ m.logger.ErrorContext(ctx, "connection upload closed: ", err)
|
|
|
+ } else {
|
|
|
+ m.logger.TraceContext(ctx, "connection upload closed")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if err == nil {
|
|
|
+ m.logger.DebugContext(ctx, "connection download finished")
|
|
|
+ } else if !E.IsClosedOrCanceled(err) {
|
|
|
+ m.logger.ErrorContext(ctx, "connection download closed: ", err)
|
|
|
+ } else {
|
|
|
+ m.logger.TraceContext(ctx, "connection download closed")
|
|
|
+ }
|
|
|
}
|
|
|
- _, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/
|
|
|
+}
|
|
|
+
|
|
|
+func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
|
|
|
+ _, err := bufio.CopyPacket(destination, source)
|
|
|
if !direction {
|
|
|
if E.IsClosedOrCanceled(err) {
|
|
|
m.logger.TraceContext(ctx, "packet upload closed")
|
|
@@ -293,58 +268,7 @@ func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.P
|
|
|
}
|
|
|
}
|
|
|
if !done.Swap(true) {
|
|
|
- if onClose != nil {
|
|
|
- onClose(err)
|
|
|
- }
|
|
|
+ onClose(err)
|
|
|
}
|
|
|
common.Close(source, destination)
|
|
|
}
|
|
|
-
|
|
|
-/*type udpHijacker struct {
|
|
|
- ctx context.Context
|
|
|
- logger logger.ContextLogger
|
|
|
- source io.Closer
|
|
|
- destination N.PacketWriter
|
|
|
- direction bool
|
|
|
- readCounters []N.CountFunc
|
|
|
- writeCounters []N.CountFunc
|
|
|
- done *atomic.Bool
|
|
|
- onClose N.CloseHandlerFunc
|
|
|
-}
|
|
|
-
|
|
|
-func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) {
|
|
|
- dataLen := buffer.Len()
|
|
|
- for _, counter := range u.readCounters {
|
|
|
- counter(int64(dataLen))
|
|
|
- }
|
|
|
- err := u.destination.WritePacket(buffer, source)
|
|
|
- if err != nil {
|
|
|
- common.Close(u.source, u.destination)
|
|
|
- u.logger.DebugContext(u.ctx, "packet upload closed: ", err)
|
|
|
- return
|
|
|
- }
|
|
|
- for _, counter := range u.writeCounters {
|
|
|
- counter(int64(dataLen))
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func (u *udpHijacker) Close() error {
|
|
|
- var err error
|
|
|
- if !u.done.Swap(true) {
|
|
|
- err = common.Close(u.source, u.destination)
|
|
|
- if u.onClose != nil {
|
|
|
- u.onClose(net.ErrClosed)
|
|
|
- }
|
|
|
- }
|
|
|
- if u.direction {
|
|
|
- u.logger.TraceContext(u.ctx, "packet download closed")
|
|
|
- } else {
|
|
|
- u.logger.TraceContext(u.ctx, "packet upload closed")
|
|
|
- }
|
|
|
- return err
|
|
|
-}
|
|
|
-
|
|
|
-func (u *udpHijacker) Upstream() any {
|
|
|
- return u.destination
|
|
|
-}
|
|
|
-*/
|