conn.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package route
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/netip"
  8. "os"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/common/dialer"
  14. C "github.com/sagernet/sing-box/constant"
  15. "github.com/sagernet/sing/common"
  16. "github.com/sagernet/sing/common/buf"
  17. "github.com/sagernet/sing/common/bufio"
  18. "github.com/sagernet/sing/common/canceler"
  19. E "github.com/sagernet/sing/common/exceptions"
  20. "github.com/sagernet/sing/common/logger"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. "github.com/sagernet/sing/common/x/list"
  24. )
  25. var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
  26. type ConnectionManager struct {
  27. logger logger.ContextLogger
  28. access sync.Mutex
  29. connections list.List[io.Closer]
  30. }
  31. func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
  32. return &ConnectionManager{
  33. logger: logger,
  34. }
  35. }
  36. func (m *ConnectionManager) Start(stage adapter.StartStage) error {
  37. return nil
  38. }
  39. func (m *ConnectionManager) Close() error {
  40. m.access.Lock()
  41. defer m.access.Unlock()
  42. for element := m.connections.Front(); element != nil; element = element.Next() {
  43. common.Close(element.Value)
  44. }
  45. m.connections.Init()
  46. return nil
  47. }
  48. func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  49. ctx = adapter.WithContext(ctx, &metadata)
  50. var (
  51. remoteConn net.Conn
  52. err error
  53. )
  54. if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
  55. remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  56. } else {
  57. remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
  58. }
  59. if err != nil {
  60. err = E.Cause(err, "open outbound connection")
  61. N.CloseOnHandshakeFailure(conn, onClose, err)
  62. m.logger.ErrorContext(ctx, err)
  63. return
  64. }
  65. err = N.ReportConnHandshakeSuccess(conn, remoteConn)
  66. if err != nil {
  67. err = E.Cause(err, "report handshake success")
  68. remoteConn.Close()
  69. N.CloseOnHandshakeFailure(conn, onClose, err)
  70. m.logger.ErrorContext(ctx, err)
  71. return
  72. }
  73. m.access.Lock()
  74. element := m.connections.PushBack(conn)
  75. m.access.Unlock()
  76. onClose = N.AppendClose(onClose, func(it error) {
  77. m.access.Lock()
  78. defer m.access.Unlock()
  79. m.connections.Remove(element)
  80. })
  81. var done atomic.Bool
  82. go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
  83. go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
  84. }
  85. func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  86. ctx = adapter.WithContext(ctx, &metadata)
  87. var (
  88. remotePacketConn net.PacketConn
  89. remoteConn net.Conn
  90. destinationAddress netip.Addr
  91. err error
  92. )
  93. if metadata.UDPConnect {
  94. parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
  95. if len(metadata.DestinationAddresses) > 0 {
  96. if isParallelDialer {
  97. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  98. } else {
  99. remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
  100. }
  101. } else if metadata.Destination.IsIP() {
  102. if isParallelDialer {
  103. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  104. } else {
  105. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  106. }
  107. } else {
  108. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  109. }
  110. if err != nil {
  111. N.CloseOnHandshakeFailure(conn, onClose, err)
  112. m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
  113. return
  114. }
  115. remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
  116. connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
  117. if connRemoteAddr != metadata.Destination.Addr {
  118. destinationAddress = connRemoteAddr
  119. }
  120. } else {
  121. if len(metadata.DestinationAddresses) > 0 {
  122. remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  123. } else {
  124. remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
  125. }
  126. if err != nil {
  127. N.CloseOnHandshakeFailure(conn, onClose, err)
  128. m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
  129. return
  130. }
  131. }
  132. err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
  133. if err != nil {
  134. conn.Close()
  135. remotePacketConn.Close()
  136. m.logger.ErrorContext(ctx, "report handshake success: ", err)
  137. return
  138. }
  139. if destinationAddress.IsValid() {
  140. var originDestination M.Socksaddr
  141. if metadata.RouteOriginalDestination.IsValid() {
  142. originDestination = metadata.RouteOriginalDestination
  143. } else {
  144. originDestination = metadata.Destination
  145. }
  146. if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
  147. natConn.UpdateDestination(destinationAddress)
  148. } else if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
  149. if metadata.UDPDisableDomainUnmapping {
  150. remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  151. } else {
  152. remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  153. }
  154. }
  155. } else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
  156. remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
  157. }
  158. var udpTimeout time.Duration
  159. if metadata.UDPTimeout > 0 {
  160. udpTimeout = metadata.UDPTimeout
  161. } else {
  162. protocol := metadata.Protocol
  163. if protocol == "" {
  164. protocol = C.PortProtocols[metadata.Destination.Port]
  165. }
  166. if protocol != "" {
  167. udpTimeout = C.ProtocolTimeouts[protocol]
  168. }
  169. }
  170. if udpTimeout > 0 {
  171. ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
  172. }
  173. destination := bufio.NewPacketConn(remotePacketConn)
  174. m.access.Lock()
  175. element := m.connections.PushBack(conn)
  176. m.access.Unlock()
  177. onClose = N.AppendClose(onClose, func(it error) {
  178. m.access.Lock()
  179. defer m.access.Unlock()
  180. m.connections.Remove(element)
  181. })
  182. var done atomic.Bool
  183. go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
  184. go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
  185. }
  186. func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  187. var (
  188. sourceReader io.Reader = source
  189. destinationWriter io.Writer = destination
  190. )
  191. var readCounters, writeCounters []N.CountFunc
  192. for {
  193. sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
  194. destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
  195. if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
  196. cachedBuffer := cachedSrc.ReadCached()
  197. if cachedBuffer != nil {
  198. dataLen := cachedBuffer.Len()
  199. _, err := destination.Write(cachedBuffer.Bytes())
  200. cachedBuffer.Release()
  201. if err != nil {
  202. if done.Swap(true) {
  203. onClose(err)
  204. }
  205. common.Close(source, destination)
  206. if !direction {
  207. m.logger.ErrorContext(ctx, "connection upload payload: ", err)
  208. } else {
  209. m.logger.ErrorContext(ctx, "connection download payload: ", err)
  210. }
  211. return
  212. }
  213. for _, counter := range readCounters {
  214. counter(int64(dataLen))
  215. }
  216. for _, counter := range writeCounters {
  217. counter(int64(dataLen))
  218. }
  219. }
  220. continue
  221. }
  222. break
  223. }
  224. if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](destinationWriter); isEarlyConn && earlyConn.NeedHandshake() {
  225. err := m.connectionCopyEarly(source, destination)
  226. if err != nil {
  227. if done.Swap(true) {
  228. onClose(err)
  229. }
  230. common.Close(source, destination)
  231. if !direction {
  232. m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
  233. } else {
  234. m.logger.ErrorContext(ctx, "connection download handshake: ", err)
  235. }
  236. return
  237. }
  238. }
  239. _, err := bufio.CopyWithCounters(destination, sourceReader, source, readCounters, writeCounters)
  240. if err != nil {
  241. common.Close(source, destination)
  242. } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
  243. err = duplexDst.CloseWrite()
  244. if err != nil {
  245. common.Close(source, destination)
  246. }
  247. } else {
  248. destination.Close()
  249. }
  250. if done.Swap(true) {
  251. onClose(err)
  252. common.Close(source, destination)
  253. }
  254. if !direction {
  255. if err == nil {
  256. m.logger.DebugContext(ctx, "connection upload finished")
  257. } else if !E.IsClosedOrCanceled(err) {
  258. m.logger.ErrorContext(ctx, "connection upload closed: ", err)
  259. } else {
  260. m.logger.TraceContext(ctx, "connection upload closed")
  261. }
  262. } else {
  263. if err == nil {
  264. m.logger.DebugContext(ctx, "connection download finished")
  265. } else if !E.IsClosedOrCanceled(err) {
  266. m.logger.ErrorContext(ctx, "connection download closed: ", err)
  267. } else {
  268. m.logger.TraceContext(ctx, "connection download closed")
  269. }
  270. }
  271. }
  272. func (m *ConnectionManager) connectionCopyEarly(source net.Conn, destination io.Writer) error {
  273. payload := buf.NewPacket()
  274. defer payload.Release()
  275. err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
  276. if err != nil {
  277. if err == os.ErrInvalid {
  278. return common.Error(destination.Write(nil))
  279. }
  280. return err
  281. }
  282. _, err = payload.ReadOnceFrom(source)
  283. if err != nil && !(E.IsTimeout(err) || errors.Is(err, io.EOF)) {
  284. return E.Cause(err, "read payload")
  285. }
  286. _ = source.SetReadDeadline(time.Time{})
  287. _, err = destination.Write(payload.Bytes())
  288. if err != nil {
  289. return E.Cause(err, "write payload")
  290. }
  291. return nil
  292. }
  293. func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  294. _, err := bufio.CopyPacket(destination, source)
  295. if !direction {
  296. if err == nil {
  297. m.logger.DebugContext(ctx, "packet upload finished")
  298. } else if E.IsClosedOrCanceled(err) {
  299. m.logger.TraceContext(ctx, "packet upload closed")
  300. } else {
  301. m.logger.DebugContext(ctx, "packet upload closed: ", err)
  302. }
  303. } else {
  304. if err == nil {
  305. m.logger.DebugContext(ctx, "packet download finished")
  306. } else if E.IsClosedOrCanceled(err) {
  307. m.logger.TraceContext(ctx, "packet download closed")
  308. } else {
  309. m.logger.DebugContext(ctx, "packet download closed: ", err)
  310. }
  311. }
  312. if !done.Swap(true) {
  313. onClose(err)
  314. }
  315. common.Close(source, destination)
  316. }