conn.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. package route
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "net/netip"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/sagernet/sing-box/adapter"
  11. "github.com/sagernet/sing-box/common/dialer"
  12. C "github.com/sagernet/sing-box/constant"
  13. "github.com/sagernet/sing/common"
  14. "github.com/sagernet/sing/common/bufio"
  15. "github.com/sagernet/sing/common/canceler"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. "github.com/sagernet/sing/common/logger"
  18. M "github.com/sagernet/sing/common/metadata"
  19. N "github.com/sagernet/sing/common/network"
  20. "github.com/sagernet/sing/common/x/list"
  21. )
  22. var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
  23. type ConnectionManager struct {
  24. logger logger.ContextLogger
  25. access sync.Mutex
  26. connections list.List[io.Closer]
  27. }
  28. func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
  29. return &ConnectionManager{
  30. logger: logger,
  31. }
  32. }
  33. func (m *ConnectionManager) Start(stage adapter.StartStage) error {
  34. return nil
  35. }
  36. func (m *ConnectionManager) Close() error {
  37. m.access.Lock()
  38. defer m.access.Unlock()
  39. for element := m.connections.Front(); element != nil; element = element.Next() {
  40. common.Close(element.Value)
  41. }
  42. m.connections.Init()
  43. return nil
  44. }
  45. func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  46. ctx = adapter.WithContext(ctx, &metadata)
  47. var (
  48. remoteConn net.Conn
  49. err error
  50. )
  51. if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
  52. remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  53. } else {
  54. remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
  55. }
  56. if err != nil {
  57. err = E.Cause(err, "open outbound connection")
  58. N.CloseOnHandshakeFailure(conn, onClose, err)
  59. m.logger.ErrorContext(ctx, err)
  60. return
  61. }
  62. err = N.ReportConnHandshakeSuccess(conn, remoteConn)
  63. if err != nil {
  64. err = E.Cause(err, "report handshake success")
  65. remoteConn.Close()
  66. N.CloseOnHandshakeFailure(conn, onClose, err)
  67. m.logger.ErrorContext(ctx, err)
  68. return
  69. }
  70. m.access.Lock()
  71. element := m.connections.PushBack(conn)
  72. m.access.Unlock()
  73. onClose = N.AppendClose(onClose, func(it error) {
  74. m.access.Lock()
  75. defer m.access.Unlock()
  76. m.connections.Remove(element)
  77. })
  78. var done atomic.Bool
  79. go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
  80. go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
  81. }
  82. func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  83. ctx = adapter.WithContext(ctx, &metadata)
  84. var (
  85. remotePacketConn net.PacketConn
  86. remoteConn net.Conn
  87. destinationAddress netip.Addr
  88. err error
  89. )
  90. if metadata.UDPConnect {
  91. parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
  92. if len(metadata.DestinationAddresses) > 0 {
  93. if isParallelDialer {
  94. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  95. } else {
  96. remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
  97. }
  98. } else if metadata.Destination.IsIP() {
  99. if isParallelDialer {
  100. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  101. } else {
  102. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  103. }
  104. } else {
  105. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  106. }
  107. if err != nil {
  108. N.CloseOnHandshakeFailure(conn, onClose, err)
  109. m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
  110. return
  111. }
  112. remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
  113. connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
  114. if connRemoteAddr != metadata.Destination.Addr {
  115. destinationAddress = connRemoteAddr
  116. }
  117. } else {
  118. if len(metadata.DestinationAddresses) > 0 {
  119. remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  120. } else {
  121. remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
  122. }
  123. if err != nil {
  124. N.CloseOnHandshakeFailure(conn, onClose, err)
  125. m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
  126. return
  127. }
  128. }
  129. err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
  130. if err != nil {
  131. conn.Close()
  132. remotePacketConn.Close()
  133. m.logger.ErrorContext(ctx, "report handshake success: ", err)
  134. return
  135. }
  136. if destinationAddress.IsValid() {
  137. var originDestination M.Socksaddr
  138. if metadata.RouteOriginalDestination.IsValid() {
  139. originDestination = metadata.RouteOriginalDestination
  140. } else {
  141. originDestination = metadata.Destination
  142. }
  143. if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
  144. if metadata.UDPDisableDomainUnmapping {
  145. remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  146. } else {
  147. remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  148. }
  149. }
  150. if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
  151. natConn.UpdateDestination(destinationAddress)
  152. }
  153. }
  154. var udpTimeout time.Duration
  155. if metadata.UDPTimeout > 0 {
  156. udpTimeout = metadata.UDPTimeout
  157. } else {
  158. protocol := metadata.Protocol
  159. if protocol == "" {
  160. protocol = C.PortProtocols[metadata.Destination.Port]
  161. }
  162. if protocol != "" {
  163. udpTimeout = C.ProtocolTimeouts[protocol]
  164. }
  165. }
  166. if udpTimeout > 0 {
  167. ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
  168. }
  169. destination := bufio.NewPacketConn(remotePacketConn)
  170. m.access.Lock()
  171. element := m.connections.PushBack(conn)
  172. m.access.Unlock()
  173. onClose = N.AppendClose(onClose, func(it error) {
  174. m.access.Lock()
  175. defer m.access.Unlock()
  176. m.connections.Remove(element)
  177. })
  178. var done atomic.Bool
  179. go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
  180. go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
  181. }
  182. func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  183. originSource := source
  184. originDestination := destination
  185. var readCounters, writeCounters []N.CountFunc
  186. for {
  187. source, readCounters = N.UnwrapCountReader(source, readCounters)
  188. destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
  189. if cachedSrc, isCached := source.(N.CachedReader); isCached {
  190. cachedBuffer := cachedSrc.ReadCached()
  191. if cachedBuffer != nil {
  192. dataLen := cachedBuffer.Len()
  193. _, err := destination.Write(cachedBuffer.Bytes())
  194. cachedBuffer.Release()
  195. if err != nil {
  196. if done.Swap(true) {
  197. onClose(err)
  198. }
  199. common.Close(originSource, originDestination)
  200. if !direction {
  201. m.logger.ErrorContext(ctx, "connection upload payload: ", err)
  202. } else {
  203. m.logger.ErrorContext(ctx, "connection download payload: ", err)
  204. }
  205. return
  206. }
  207. for _, counter := range readCounters {
  208. counter(int64(dataLen))
  209. }
  210. for _, counter := range writeCounters {
  211. counter(int64(dataLen))
  212. }
  213. }
  214. continue
  215. }
  216. break
  217. }
  218. _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
  219. if err != nil {
  220. common.Close(originDestination)
  221. } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
  222. err = duplexDst.CloseWrite()
  223. if err != nil {
  224. common.Close(originSource, originDestination)
  225. }
  226. } else {
  227. common.Close(originDestination)
  228. }
  229. if done.Swap(true) {
  230. onClose(err)
  231. common.Close(originSource, originDestination)
  232. }
  233. if !direction {
  234. if err == nil {
  235. m.logger.DebugContext(ctx, "connection upload finished")
  236. } else if !E.IsClosedOrCanceled(err) {
  237. m.logger.ErrorContext(ctx, "connection upload closed: ", err)
  238. } else {
  239. m.logger.TraceContext(ctx, "connection upload closed")
  240. }
  241. } else {
  242. if err == nil {
  243. m.logger.DebugContext(ctx, "connection download finished")
  244. } else if !E.IsClosedOrCanceled(err) {
  245. m.logger.ErrorContext(ctx, "connection download closed: ", err)
  246. } else {
  247. m.logger.TraceContext(ctx, "connection download closed")
  248. }
  249. }
  250. }
  251. func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  252. _, err := bufio.CopyPacket(destination, source)
  253. if !direction {
  254. if E.IsClosedOrCanceled(err) {
  255. m.logger.TraceContext(ctx, "packet upload closed")
  256. } else {
  257. m.logger.DebugContext(ctx, "packet upload closed: ", err)
  258. }
  259. } else {
  260. if E.IsClosedOrCanceled(err) {
  261. m.logger.TraceContext(ctx, "packet download closed")
  262. } else {
  263. m.logger.DebugContext(ctx, "packet download closed: ", err)
  264. }
  265. }
  266. if !done.Swap(true) {
  267. onClose(err)
  268. }
  269. common.Close(source, destination)
  270. }