conn.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. package route
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/netip"
  8. "os"
  9. "strings"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/sagernet/sing-box/adapter"
  14. "github.com/sagernet/sing-box/common/dialer"
  15. "github.com/sagernet/sing-box/common/tlsfragment"
  16. C "github.com/sagernet/sing-box/constant"
  17. "github.com/sagernet/sing/common"
  18. "github.com/sagernet/sing/common/buf"
  19. "github.com/sagernet/sing/common/bufio"
  20. "github.com/sagernet/sing/common/canceler"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. "github.com/sagernet/sing/common/logger"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/common/x/list"
  26. )
  27. var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
  28. type ConnectionManager struct {
  29. logger logger.ContextLogger
  30. access sync.Mutex
  31. connections list.List[io.Closer]
  32. }
  33. func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
  34. return &ConnectionManager{
  35. logger: logger,
  36. }
  37. }
  38. func (m *ConnectionManager) Start(stage adapter.StartStage) error {
  39. return nil
  40. }
  41. func (m *ConnectionManager) Close() error {
  42. m.access.Lock()
  43. defer m.access.Unlock()
  44. for element := m.connections.Front(); element != nil; element = element.Next() {
  45. common.Close(element.Value)
  46. }
  47. m.connections.Init()
  48. return nil
  49. }
  50. func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  51. ctx = adapter.WithContext(ctx, &metadata)
  52. var (
  53. remoteConn net.Conn
  54. err error
  55. )
  56. if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() {
  57. remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  58. } else {
  59. remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
  60. }
  61. if err != nil {
  62. var remoteString string
  63. if len(metadata.DestinationAddresses) > 0 {
  64. remoteString = "[" + strings.Join(common.Map(metadata.DestinationAddresses, netip.Addr.String), ",") + "]"
  65. } else {
  66. remoteString = metadata.Destination.String()
  67. }
  68. var dialerString string
  69. if outbound, isOutbound := this.(adapter.Outbound); isOutbound {
  70. dialerString = " using outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  71. }
  72. err = E.Cause(err, "open connection to ", remoteString, dialerString)
  73. N.CloseOnHandshakeFailure(conn, onClose, err)
  74. m.logger.ErrorContext(ctx, err)
  75. return
  76. }
  77. err = N.ReportConnHandshakeSuccess(conn, remoteConn)
  78. if err != nil {
  79. err = E.Cause(err, "report handshake success")
  80. remoteConn.Close()
  81. N.CloseOnHandshakeFailure(conn, onClose, err)
  82. m.logger.ErrorContext(ctx, err)
  83. return
  84. }
  85. if metadata.TLSFragment || metadata.TLSRecordFragment {
  86. remoteConn = tf.NewConn(remoteConn, ctx, metadata.TLSFragment, metadata.TLSRecordFragment, metadata.TLSFragmentFallbackDelay)
  87. }
  88. m.access.Lock()
  89. element := m.connections.PushBack(conn)
  90. m.access.Unlock()
  91. onClose = N.AppendClose(onClose, func(it error) {
  92. m.access.Lock()
  93. defer m.access.Unlock()
  94. m.connections.Remove(element)
  95. })
  96. var done atomic.Bool
  97. m.preConnectionCopy(ctx, conn, remoteConn, false, &done, onClose)
  98. m.preConnectionCopy(ctx, remoteConn, conn, true, &done, onClose)
  99. go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose)
  100. go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose)
  101. }
  102. func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
  103. ctx = adapter.WithContext(ctx, &metadata)
  104. var (
  105. remotePacketConn net.PacketConn
  106. remoteConn net.Conn
  107. destinationAddress netip.Addr
  108. err error
  109. )
  110. if metadata.UDPConnect {
  111. parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
  112. if len(metadata.DestinationAddresses) > 0 {
  113. if isParallelDialer {
  114. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  115. } else {
  116. remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
  117. }
  118. } else if metadata.Destination.IsIP() {
  119. if isParallelDialer {
  120. remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  121. } else {
  122. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  123. }
  124. } else {
  125. remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
  126. }
  127. if err != nil {
  128. var remoteString string
  129. if len(metadata.DestinationAddresses) > 0 {
  130. remoteString = "[" + strings.Join(common.Map(metadata.DestinationAddresses, netip.Addr.String), ",") + "]"
  131. } else {
  132. remoteString = metadata.Destination.String()
  133. }
  134. var dialerString string
  135. if outbound, isOutbound := this.(adapter.Outbound); isOutbound {
  136. dialerString = " using outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  137. }
  138. err = E.Cause(err, "open packet connection to ", remoteString, dialerString)
  139. N.CloseOnHandshakeFailure(conn, onClose, err)
  140. m.logger.ErrorContext(ctx, err)
  141. return
  142. }
  143. remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
  144. connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
  145. if connRemoteAddr != metadata.Destination.Addr {
  146. destinationAddress = connRemoteAddr
  147. }
  148. } else {
  149. if len(metadata.DestinationAddresses) > 0 {
  150. remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
  151. } else {
  152. remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
  153. }
  154. if err != nil {
  155. var dialerString string
  156. if outbound, isOutbound := this.(adapter.Outbound); isOutbound {
  157. dialerString = " using outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  158. }
  159. err = E.Cause(err, "listen packet connection using ", dialerString)
  160. N.CloseOnHandshakeFailure(conn, onClose, err)
  161. m.logger.ErrorContext(ctx, err)
  162. return
  163. }
  164. }
  165. err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
  166. if err != nil {
  167. conn.Close()
  168. remotePacketConn.Close()
  169. m.logger.ErrorContext(ctx, "report handshake success: ", err)
  170. return
  171. }
  172. if destinationAddress.IsValid() {
  173. var originDestination M.Socksaddr
  174. if metadata.RouteOriginalDestination.IsValid() {
  175. originDestination = metadata.RouteOriginalDestination
  176. } else {
  177. originDestination = metadata.Destination
  178. }
  179. if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
  180. natConn.UpdateDestination(destinationAddress)
  181. } else if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
  182. if metadata.UDPDisableDomainUnmapping {
  183. remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  184. } else {
  185. remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
  186. }
  187. }
  188. } else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
  189. remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
  190. }
  191. var udpTimeout time.Duration
  192. if metadata.UDPTimeout > 0 {
  193. udpTimeout = metadata.UDPTimeout
  194. } else {
  195. protocol := metadata.Protocol
  196. if protocol == "" {
  197. protocol = C.PortProtocols[metadata.Destination.Port]
  198. }
  199. if protocol != "" {
  200. udpTimeout = C.ProtocolTimeouts[protocol]
  201. }
  202. }
  203. if udpTimeout > 0 {
  204. ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
  205. }
  206. destination := bufio.NewPacketConn(remotePacketConn)
  207. m.access.Lock()
  208. element := m.connections.PushBack(conn)
  209. m.access.Unlock()
  210. onClose = N.AppendClose(onClose, func(it error) {
  211. m.access.Lock()
  212. defer m.access.Unlock()
  213. m.connections.Remove(element)
  214. })
  215. var done atomic.Bool
  216. go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
  217. go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
  218. }
  219. func (m *ConnectionManager) preConnectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  220. readHandshake := N.NeedHandshakeForRead(source)
  221. writeHandshake := N.NeedHandshakeForWrite(destination)
  222. if readHandshake || writeHandshake {
  223. var err error
  224. for {
  225. err = m.connectionCopyEarlyWrite(source, destination, readHandshake, writeHandshake)
  226. if err == nil && N.NeedHandshakeForRead(source) {
  227. continue
  228. } else if E.IsMulti(err, os.ErrInvalid, context.DeadlineExceeded, io.EOF) {
  229. err = nil
  230. }
  231. break
  232. }
  233. if err != nil {
  234. if done.Swap(true) {
  235. onClose(err)
  236. }
  237. common.Close(source, destination)
  238. if !direction {
  239. m.logger.ErrorContext(ctx, "connection upload handshake: ", err)
  240. } else {
  241. m.logger.ErrorContext(ctx, "connection download handshake: ", err)
  242. }
  243. return
  244. }
  245. }
  246. }
  247. func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  248. var (
  249. sourceReader io.Reader = source
  250. destinationWriter io.Writer = destination
  251. )
  252. var readCounters, writeCounters []N.CountFunc
  253. for {
  254. sourceReader, readCounters = N.UnwrapCountReader(sourceReader, readCounters)
  255. destinationWriter, writeCounters = N.UnwrapCountWriter(destinationWriter, writeCounters)
  256. if cachedSrc, isCached := sourceReader.(N.CachedReader); isCached {
  257. cachedBuffer := cachedSrc.ReadCached()
  258. if cachedBuffer != nil {
  259. dataLen := cachedBuffer.Len()
  260. _, err := destination.Write(cachedBuffer.Bytes())
  261. cachedBuffer.Release()
  262. if err != nil {
  263. if done.Swap(true) {
  264. onClose(err)
  265. }
  266. common.Close(source, destination)
  267. if !direction {
  268. m.logger.ErrorContext(ctx, "connection upload payload: ", err)
  269. } else {
  270. m.logger.ErrorContext(ctx, "connection download payload: ", err)
  271. }
  272. return
  273. }
  274. for _, counter := range readCounters {
  275. counter(int64(dataLen))
  276. }
  277. for _, counter := range writeCounters {
  278. counter(int64(dataLen))
  279. }
  280. }
  281. continue
  282. }
  283. break
  284. }
  285. _, err := bufio.CopyWithCounters(destinationWriter, sourceReader, source, readCounters, writeCounters, bufio.DefaultIncreaseBufferAfter, bufio.DefaultBatchSize)
  286. if err != nil {
  287. common.Close(source, destination)
  288. } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex {
  289. err = duplexDst.CloseWrite()
  290. if err != nil {
  291. common.Close(source, destination)
  292. }
  293. } else {
  294. destination.Close()
  295. }
  296. if done.Swap(true) {
  297. onClose(err)
  298. common.Close(source, destination)
  299. }
  300. if !direction {
  301. if err == nil {
  302. m.logger.DebugContext(ctx, "connection upload finished")
  303. } else if !E.IsClosedOrCanceled(err) {
  304. m.logger.ErrorContext(ctx, "connection upload closed: ", err)
  305. } else {
  306. m.logger.TraceContext(ctx, "connection upload closed")
  307. }
  308. } else {
  309. if err == nil {
  310. m.logger.DebugContext(ctx, "connection download finished")
  311. } else if !E.IsClosedOrCanceled(err) {
  312. m.logger.ErrorContext(ctx, "connection download closed: ", err)
  313. } else {
  314. m.logger.TraceContext(ctx, "connection download closed")
  315. }
  316. }
  317. }
  318. func (m *ConnectionManager) connectionCopyEarlyWrite(source net.Conn, destination io.Writer, readHandshake bool, writeHandshake bool) error {
  319. payload := buf.NewPacket()
  320. defer payload.Release()
  321. err := source.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
  322. if err != nil {
  323. if err == os.ErrInvalid {
  324. if writeHandshake {
  325. return common.Error(destination.Write(nil))
  326. }
  327. }
  328. return err
  329. }
  330. var (
  331. isTimeout bool
  332. isEOF bool
  333. )
  334. _, err = payload.ReadOnceFrom(source)
  335. if err != nil {
  336. if E.IsTimeout(err) {
  337. isTimeout = true
  338. } else if errors.Is(err, io.EOF) {
  339. isEOF = true
  340. } else {
  341. return E.Cause(err, "read payload")
  342. }
  343. }
  344. _ = source.SetReadDeadline(time.Time{})
  345. if !payload.IsEmpty() || writeHandshake {
  346. _, err = destination.Write(payload.Bytes())
  347. if err != nil {
  348. return E.Cause(err, "write payload")
  349. }
  350. }
  351. if isTimeout {
  352. return context.DeadlineExceeded
  353. } else if isEOF {
  354. return io.EOF
  355. }
  356. return nil
  357. }
  358. func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
  359. _, err := bufio.CopyPacket(destination, source)
  360. if !direction {
  361. if err == nil {
  362. m.logger.DebugContext(ctx, "packet upload finished")
  363. } else if E.IsClosedOrCanceled(err) {
  364. m.logger.TraceContext(ctx, "packet upload closed")
  365. } else {
  366. m.logger.DebugContext(ctx, "packet upload closed: ", err)
  367. }
  368. } else {
  369. if err == nil {
  370. m.logger.DebugContext(ctx, "packet download finished")
  371. } else if E.IsClosedOrCanceled(err) {
  372. m.logger.TraceContext(ctx, "packet download closed")
  373. } else {
  374. m.logger.DebugContext(ctx, "packet download closed: ", err)
  375. }
  376. }
  377. if !done.Swap(true) {
  378. onClose(err)
  379. }
  380. common.Close(source, destination)
  381. }