conn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package route
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/netip"
  8. "os"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/common/dialer"
  14. "github.com/sagernet/sing-box/common/tlsfragment"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing/common"
  17. "github.com/sagernet/sing/common/buf"
  18. "github.com/sagernet/sing/common/bufio"
  19. "github.com/sagernet/sing/common/canceler"
  20. E "github.com/sagernet/sing/common/exceptions"
  21. "github.com/sagernet/sing/common/logger"
  22. M "github.com/sagernet/sing/common/metadata"
  23. N "github.com/sagernet/sing/common/network"
  24. "github.com/sagernet/sing/common/x/list"
  25. )
  26. var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
  27. type ConnectionManager struct {
  28. logger logger.ContextLogger
  29. access sync.Mutex
  30. connections list.List[io.Closer]
  31. }
  32. func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
  33. return &ConnectionManager{
  34. logger: logger,
  35. }
  36. }
  37. func (m *ConnectionManager) Start(stage adapter.StartStage) error {
  38. return nil
  39. }
  40. func (m *ConnectionManager) Close() error {
  41. m.access.Lock()
  42. defer m.access.Unlock()
  43. for element := m.connections.Front(); element != nil; element = element.Next() {
  44. common.Close(element.Value)
  45. }
  46. m.connections.Init()
  47. return nil
  48. }
  // NewConnection dials metadata.Destination over TCP through this and then
// relays data between conn and the outbound connection. It starts two
// goroutines, upload (conn -> remote) and download (remote -> conn), and
// returns without waiting for them.
//
// Every failure before the relay starts closes conn and, when onClose is
// non-nil, passes the error to onClose, so the caller's cleanup always runs.
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
	ctx = adapter.WithContext(ctx, &metadata)
	var (
		remoteConn net.Conn
		err        error
	)
	if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
		// Concrete addresses are known, so honour the configured network
		// strategy and fallback settings.
		remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
	} else {
		remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
	}
	if err != nil {
		err = E.Cause(err, "open outbound connection")
		N.CloseOnHandshakeFailure(conn, onClose, err)
		m.logger.ErrorContext(ctx, err)
		return
	}
	err = N.ReportConnHandshakeSuccess(conn, remoteConn)
	if err != nil {
		err = E.Cause(err, "report handshake success")
		remoteConn.Close()
		N.CloseOnHandshakeFailure(conn, onClose, err)
		m.logger.ErrorContext(ctx, err)
		return
	}
	if metadata.TLSFragment {
		fallbackDelay := metadata.TLSFragmentFallbackDelay
		if fallbackDelay == 0 {
			fallbackDelay = C.TLSFragmentFallbackDelay
		}
		var newConn *tf.Conn
		newConn, err = tf.NewConn(remoteConn, ctx, fallbackDelay)
		if err != nil {
			err = E.Cause(err, "create TLS fragment connection")
			// Success was already reported to the inbound side, so close
			// plainly rather than through CloseOnHandshakeFailure. The caller
			// must still be told the connection is gone.
			conn.Close()
			remoteConn.Close()
			if onClose != nil {
				onClose(err)
			}
			m.logger.ErrorContext(ctx, err)
			return
		}
		remoteConn = newConn
	}
	m.access.Lock()
	element := m.connections.PushBack(conn)
	m.access.Unlock()
	onClose = N.AppendClose(onClose, func(it error) {
		m.access.Lock()
		defer m.access.Unlock()
		m.connections.Remove(element)
	})
	// done is shared by both directions; see connectionCopy.
	var done atomic.Bool
	go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
	go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
}
  // NewPacketConnection opens an outbound packet connection for
// metadata.Destination through this and relays packets between conn and it.
// It starts two goroutines and returns without waiting for them.
//
// When metadata.UDPConnect is set, the outbound socket is dialed (connected)
// to a single peer. Otherwise an unconnected packet listener is used. If the
// peer address actually chosen differs from metadata.Destination, the
// outbound side is wrapped in a NAT packet conn. (Presumably this maps
// traffic back to the original destination; confirm against sing's bufio.)
//
// Every failure before the relay starts closes conn and notifies onClose.
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
	ctx = adapter.WithContext(ctx, &metadata)
	var (
		remotePacketConn   net.PacketConn
		remoteConn         net.Conn
		destinationAddress netip.Addr
		err                error
	)
	if metadata.UDPConnect {
		parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
		switch {
		case isParallelDialer && (len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP()):
			remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
		case len(metadata.DestinationAddresses) > 0:
			remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
		default:
			remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
		}
		if err != nil {
			N.CloseOnHandshakeFailure(conn, onClose, err)
			m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
			return
		}
		remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
		// The connected peer can differ from metadata.Destination.Addr, for
		// example when the destination was a domain. Remember the peer so
		// packets can be translated below.
		connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
		if connRemoteAddr != metadata.Destination.Addr {
			destinationAddress = connRemoteAddr
		}
	} else {
		if len(metadata.DestinationAddresses) > 0 {
			remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
		} else {
			remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
		}
		if err != nil {
			N.CloseOnHandshakeFailure(conn, onClose, err)
			m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
			return
		}
	}
	err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
	if err != nil {
		// Mirror NewConnection: release the outbound side, then close conn
		// through CloseOnHandshakeFailure, which also notifies onClose.
		// Previously onClose was never invoked on this path.
		remotePacketConn.Close()
		N.CloseOnHandshakeFailure(conn, onClose, err)
		m.logger.ErrorContext(ctx, "report handshake success: ", err)
		return
	}
	if destinationAddress.IsValid() {
		originDestination := metadata.Destination
		if metadata.RouteOriginalDestination.IsValid() {
			originDestination = metadata.RouteOriginalDestination
		}
		actualDestination := M.SocksaddrFrom(destinationAddress, metadata.Destination.Port)
		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
			// The inbound conn supports address rewriting itself.
			natConn.UpdateDestination(destinationAddress)
		} else if metadata.Destination != actualDestination {
			if metadata.UDPDisableDomainUnmapping {
				remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), actualDestination, originDestination)
			} else {
				remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), actualDestination, originDestination)
			}
		}
	} else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
		remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
	}
	// An explicit positive UDPTimeout wins. Otherwise fall back to a
	// per-protocol timeout, guessing the protocol from the port if unknown.
	var udpTimeout time.Duration
	if metadata.UDPTimeout > 0 {
		udpTimeout = metadata.UDPTimeout
	} else {
		protocol := metadata.Protocol
		if protocol == "" {
			protocol = C.PortProtocols[metadata.Destination.Port]
		}
		if protocol != "" {
			udpTimeout = C.ProtocolTimeouts[protocol]
		}
	}
	if udpTimeout > 0 {
		ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
	}
	destination := bufio.NewPacketConn(remotePacketConn)
	m.access.Lock()
	element := m.connections.PushBack(conn)
	m.access.Unlock()
	onClose = N.AppendClose(onClose, func(it error) {
		m.access.Lock()
		defer m.access.Unlock()
		m.connections.Remove(element)
	})
	// done is shared by both directions; see packetConnectionCopy.
	var done atomic.Bool
	go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
	go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
}
  // connectionCopy relays one direction of a TCP relay from source to
// destination. direction is false for upload and true for download; it only
// selects log messages.
//
// The two goroutines of a relay share done. A goroutine that finds done
// already set (i.e. it finished second) invokes onClose and closes both
// connections.
func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
	label := "upload"
	if direction {
		label = "download"
	}
	// fail aborts the relay on an error raised before the main copy starts.
	fail := func(stage string, err error) {
		if done.Swap(true) {
			onClose(err)
		}
		common.Close(source, destination)
		m.logger.ErrorContext(ctx, "connection "+label+" "+stage+": ", err)
	}
	var (
		sourceReader      io.Reader = source
		destinationWriter io.Writer = destination
	)
	var readCounters, writeCounters []N.CountFunc
	// Strip counting/replaceable wrappers and flush anything a CachedReader in
	// the source chain has already buffered.
	drained := false
	for {
		sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
		destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
		cachedSrc, isCached := sourceReader.(N.CachedReader)
		if !isCached {
			break
		}
		cachedBuffer := cachedSrc.ReadCached()
		if cachedBuffer == nil {
			// An emptied cache may let the wrapper be unwrapped, so allow one
			// more pass. A second consecutive empty read means no progress
			// is possible; the old unconditional `continue` spun forever on a
			// CachedReader that cannot be unwrapped.
			if drained {
				break
			}
			drained = true
			continue
		}
		drained = false
		dataLen := int64(cachedBuffer.Len())
		// Write through the unwrapped writer. Its counters are already in
		// writeCounters and are invoked below, so writing via the wrapped
		// destination would count these bytes twice.
		_, err := destinationWriter.Write(cachedBuffer.Bytes())
		cachedBuffer.Release()
		if err != nil {
			fail("payload", err)
			return
		}
		for _, counter := range readCounters {
			counter(dataLen)
		}
		for _, counter := range writeCounters {
			counter(dataLen)
		}
	}
	if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
		if err := m.connectionCopyEarly(source, destination); err != nil {
			fail("handshake", err)
			return
		}
	}
	_, err := bufio.CopyWithCounters(destinationWriter, sourceReader, source, readCounters, writeCounters)
	if err != nil {
		common.Close(source, destination)
	} else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
		// Half-close the write side so the opposite direction can keep running.
		err = duplexDst.CloseWrite()
		if err != nil {
			common.Close(source, destination)
		}
	} else {
		destination.Close()
	}
	if done.Swap(true) {
		onClose(err)
		common.Close(source, destination)
	}
	switch {
	case err == nil:
		m.logger.DebugContext(ctx, "connection "+label+" finished")
	case !E.IsClosedOrCanceled(err):
		m.logger.ErrorContext(ctx, "connection "+label+" closed: ", err)
	default:
		m.logger.TraceContext(ctx, "connection "+label+" closed")
	}
}
  // connectionCopyEarly performs the first write to a destination that still
// needs its handshake (connectionCopy calls it when the destination reports
// NeedHandshake). It waits up to C.ReadPayloadTimeout for one read from
// source and writes whatever arrived, possibly nothing, to destination.
//
// If SetReadDeadline reports os.ErrInvalid (treated here as deadlines being
// unsupported), an empty write is issued immediately instead. Timeouts and
// EOF while reading are not errors; other read or write failures are
// returned with context.
func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
	payload := buf.NewPacket()
	defer payload.Release()
	err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
	if err != nil {
		// errors.Is also matches an os.ErrInvalid wrapped by a connection
		// wrapper, which a plain == comparison would miss.
		if errors.Is(err, os.ErrInvalid) {
			return common.Error(destination.Write(nil))
		}
		return err
	}
	_, err = payload.ReadOnceFrom(source)
	if err != nil && !(E.IsTimeout(err) || errors.Is(err, io.EOF)) {
		return E.Cause(err, "read payload")
	}
	// Clear the deadline before the main copy. The error is ignored because
	// setting a deadline already succeeded above.
	_ = source.SetReadDeadline(time.Time{})
	_, err = destination.Write(payload.Bytes())
	if err != nil {
		return E.Cause(err, "write payload")
	}
	return nil
}
  // packetConnectionCopy pumps packets from source to destination until the
// copy ends, logs how it ended, and then shuts the relay down. direction is
// false for upload and true for download and only affects the log text.
//
// The first goroutine of the pair to get here (done not yet set) invokes
// onClose. Both goroutines close both ends.
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
	_, copyErr := bufio.CopyPacket(destination, source)
	switch {
	case copyErr == nil:
		if direction {
			m.logger.DebugContext(ctx, "packet download finished")
		} else {
			m.logger.DebugContext(ctx, "packet upload finished")
		}
	case E.IsClosedOrCanceled(copyErr):
		if direction {
			m.logger.TraceContext(ctx, "packet download closed")
		} else {
			m.logger.TraceContext(ctx, "packet upload closed")
		}
	default:
		if direction {
			m.logger.DebugContext(ctx, "packet download closed: ", copyErr)
		} else {
			m.logger.DebugContext(ctx, "packet upload closed: ", copyErr)
		}
	}
	if alreadyDone := done.Swap(true); !alreadyDone {
		onClose(copyErr)
	}
	common.Close(source, destination)
}