conn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package route
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "net/netip"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/sagernet/sing-box/adapter"
  11. "github.com/sagernet/sing-box/common/dialer"
  12. "github.com/sagernet/sing-box/common/tlsfragment"
  13. C "github.com/sagernet/sing-box/constant"
  14. "github.com/sagernet/sing/common"
  15. "github.com/sagernet/sing/common/bufio"
  16. "github.com/sagernet/sing/common/canceler"
  17. E "github.com/sagernet/sing/common/exceptions"
  18. "github.com/sagernet/sing/common/logger"
  19. M "github.com/sagernet/sing/common/metadata"
  20. N "github.com/sagernet/sing/common/network"
  21. "github.com/sagernet/sing/common/x/list"
  22. )
  23. var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
  24. type ConnectionManager struct {
  25. logger logger.ContextLogger
  26. access sync.Mutex
  27. connections list.List[io.Closer]
  28. }
  29. func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
  30. return &ConnectionManager{
  31. logger: logger,
  32. }
  33. }
  34. func (m *ConnectionManager) Start(stage adapter.StartStage) error {
  35. return nil
  36. }
  37. func (m *ConnectionManager) Close() error {
  38. m.access.Lock()
  39. defer m.access.Unlock()
  40. for element := m.connections.Front(); element != nil; element = element.Next() {
  41. common.Close(element.Value)
  42. }
  43. m.connections.Init()
  44. return nil
  45. }
  46. func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  47. ctx = adapter.WithContext(ctx, &metadata)
  48. var (
  49. remoteConn net.Conn
  50. err error
  51. )
  52. if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
  53. remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  54. } else {
  55. remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
  56. }
  57. if err != nil {
  58. err = E.Cause(err, "open outbound connection")
  59. N.CloseOnHandshakeFailure(conn, onClose, err)
  60. m.logger.ErrorContext(ctx, err)
  61. return
  62. }
  63. err = N.ReportConnHandshakeSuccess(conn, remoteConn)
  64. if err != nil {
  65. err = E.Cause(err, "report handshake success")
  66. remoteConn.Close()
  67. N.CloseOnHandshakeFailure(conn, onClose, err)
  68. m.logger.ErrorContext(ctx, err)
  69. return
  70. }
  71. if metadata.TLSFragment {
  72. fallbackDelay := metadata.TLSFragmentFallbackDelay
  73. if fallbackDelay == 0 {
  74. fallbackDelay = C.TLSFragmentFallbackDelay
  75. }
  76. var newConn *tf.Conn
  77. newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay)
  78. if err != nil {
  79. conn.Close()
  80. remoteConn.Close()
  81. m.logger.ErrorContext(ctx, err)
  82. return
  83. }
  84. remoteConn = newConn
  85. }
  86. m.access.Lock()
  87. element := m.connections.PushBack(conn)
  88. m.access.Unlock()
  89. onClose = N.AppendClose(onClose, func(it error) {
  90. m.access.Lock()
  91. defer m.access.Unlock()
  92. m.connections.Remove(element)
  93. })
  94. var done atomic.Bool
  95. go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
  96. go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
  97. }
  98. func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  99. ctx = adapter.WithContext(ctx, &metadata)
  100. var (
  101. remotePacketConn net.PacketConn
  102. remoteConn net.Conn
  103. destinationAddress netip.Addr
  104. err error
  105. )
  106. if metadata.UDPConnect {
  107. parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
  108. if len(metadata.DestinationAddresses) > 0 {
  109. if isParallelDialer {
  110. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  111. } else {
  112. remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
  113. }
  114. } else if metadata.Destination.IsIP() {
  115. if isParallelDialer {
  116. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  117. } else {
  118. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  119. }
  120. } else {
  121. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  122. }
  123. if err != nil {
  124. N.CloseOnHandshakeFailure(conn, onClose, err)
  125. m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
  126. return
  127. }
  128. remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
  129. connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
  130. if connRemoteAddr != metadata.Destination.Addr {
  131. destinationAddress = connRemoteAddr
  132. }
  133. } else {
  134. if len(metadata.DestinationAddresses) > 0 {
  135. remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  136. } else {
  137. remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
  138. }
  139. if err != nil {
  140. N.CloseOnHandshakeFailure(conn, onClose, err)
  141. m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
  142. return
  143. }
  144. }
  145. err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
  146. if err != nil {
  147. conn.Close()
  148. remotePacketConn.Close()
  149. m.logger.ErrorContext(ctx, "report handshake success: ", err)
  150. return
  151. }
  152. if destinationAddress.IsValid() {
  153. var originDestination M.Socksaddr
  154. if metadata.RouteOriginalDestination.IsValid() {
  155. originDestination = metadata.RouteOriginalDestination
  156. } else {
  157. originDestination = metadata.Destination
  158. }
  159. if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
  160. if metadata.UDPDisableDomainUnmapping {
  161. remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  162. } else {
  163. remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  164. }
  165. }
  166. if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
  167. natConn.UpdateDestination(destinationAddress)
  168. }
  169. } else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
  170. if metadata.UDPDisableDomainUnmapping {
  171. remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
  172. } else {
  173. remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
  174. }
  175. }
  176. var udpTimeout time.Duration
  177. if metadata.UDPTimeout > 0 {
  178. udpTimeout = metadata.UDPTimeout
  179. } else {
  180. protocol := metadata.Protocol
  181. if protocol == "" {
  182. protocol = C.PortProtocols[metadata.Destination.Port]
  183. }
  184. if protocol != "" {
  185. udpTimeout = C.ProtocolTimeouts[protocol]
  186. }
  187. }
  188. if udpTimeout > 0 {
  189. ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
  190. }
  191. destination := bufio.NewPacketConn(remotePacketConn)
  192. m.access.Lock()
  193. element := m.connections.PushBack(conn)
  194. m.access.Unlock()
  195. onClose = N.AppendClose(onClose, func(it error) {
  196. m.access.Lock()
  197. defer m.access.Unlock()
  198. m.connections.Remove(element)
  199. })
  200. var done atomic.Bool
  201. go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
  202. go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
  203. }
  204. func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  205. originSource := source
  206. originDestination := destination
  207. var readCounters, writeCounters []N.CountFunc
  208. for {
  209. source, readCounters = N.UnwrapCountReader(source, readCounters)
  210. destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
  211. if cachedSrc, isCached := source.(N.CachedReader); isCached {
  212. cachedBuffer := cachedSrc.ReadCached()
  213. if cachedBuffer != nil {
  214. dataLen := cachedBuffer.Len()
  215. _, err := destination.Write(cachedBuffer.Bytes())
  216. cachedBuffer.Release()
  217. if err != nil {
  218. if done.Swap(true) {
  219. onClose(err)
  220. }
  221. common.Close(originSource, originDestination)
  222. if !direction {
  223. m.logger.ErrorContext(ctx, "connection upload payload: ", err)
  224. } else {
  225. m.logger.ErrorContext(ctx, "connection download payload: ", err)
  226. }
  227. return
  228. }
  229. for _, counter := range readCounters {
  230. counter(int64(dataLen))
  231. }
  232. for _, counter := range writeCounters {
  233. counter(int64(dataLen))
  234. }
  235. }
  236. continue
  237. }
  238. break
  239. }
  240. if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destination); isEarlyConn && earlyConn.NeedHandshake() {
  241. _, err := destination.Write(nil)
  242. if err != nil {
  243. if !direction {
  244. m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
  245. } else {
  246. m.logger.ErrorContext(ctx, "connection download handshake: ", err)
  247. }
  248. return
  249. }
  250. }
  251. _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
  252. if err != nil {
  253. common.Close(originDestination)
  254. } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
  255. err = duplexDst.CloseWrite()
  256. if err != nil {
  257. common.Close(originSource, originDestination)
  258. }
  259. } else {
  260. common.Close(originDestination)
  261. }
  262. if done.Swap(true) {
  263. onClose(err)
  264. common.Close(originSource, originDestination)
  265. }
  266. if !direction {
  267. if err == nil {
  268. m.logger.DebugContext(ctx, "connection upload finished")
  269. } else if !E.IsClosedOrCanceled(err) {
  270. m.logger.ErrorContext(ctx, "connection upload closed: ", err)
  271. } else {
  272. m.logger.TraceContext(ctx, "connection upload closed")
  273. }
  274. } else {
  275. if err == nil {
  276. m.logger.DebugContext(ctx, "connection download finished")
  277. } else if !E.IsClosedOrCanceled(err) {
  278. m.logger.ErrorContext(ctx, "connection download closed: ", err)
  279. } else {
  280. m.logger.TraceContext(ctx, "connection download closed")
  281. }
  282. }
  283. }
  284. func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  285. _, err := bufio.CopyPacket(destination, source)
  286. if !direction {
  287. if E.IsClosedOrCanceled(err) {
  288. m.logger.TraceContext(ctx, "packet upload closed")
  289. } else {
  290. m.logger.DebugContext(ctx, "packet upload closed: ", err)
  291. }
  292. } else {
  293. if E.IsClosedOrCanceled(err) {
  294. m.logger.TraceContext(ctx, "packet download closed")
  295. } else {
  296. m.logger.DebugContext(ctx, "packet download closed: ", err)
  297. }
  298. }
  299. if !done.Swap(true) {
  300. onClose(err)
  301. }
  302. common.Close(source, destination)
  303. }