loopback_detect.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package direct
  2. import (
  3. "net"
  4. "net/netip"
  5. "sync"
  6. "github.com/sagernet/sing-box/adapter"
  7. M "github.com/sagernet/sing/common/metadata"
  8. N "github.com/sagernet/sing/common/network"
  9. )
  10. type loopBackDetector struct {
  11. networkManager adapter.NetworkManager
  12. connAccess sync.RWMutex
  13. packetConnAccess sync.RWMutex
  14. connMap map[netip.AddrPort]netip.AddrPort
  15. packetConnMap map[uint16]uint16
  16. }
  17. func newLoopBackDetector(networkManager adapter.NetworkManager) *loopBackDetector {
  18. return &loopBackDetector{
  19. networkManager: networkManager,
  20. connMap: make(map[netip.AddrPort]netip.AddrPort),
  21. packetConnMap: make(map[uint16]uint16),
  22. }
  23. }
  24. func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
  25. source := M.AddrPortFromNet(conn.LocalAddr())
  26. if !source.IsValid() {
  27. return conn
  28. }
  29. if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
  30. if !source.Addr().IsLoopback() {
  31. _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr())
  32. if err != nil {
  33. return conn
  34. }
  35. }
  36. if !N.IsPublicAddr(source.Addr()) {
  37. return conn
  38. }
  39. l.packetConnAccess.Lock()
  40. l.packetConnMap[source.Port()] = M.AddrPortFromNet(conn.RemoteAddr()).Port()
  41. l.packetConnAccess.Unlock()
  42. return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: source.Port()}
  43. } else {
  44. l.connAccess.Lock()
  45. l.connMap[source] = M.AddrPortFromNet(conn.RemoteAddr())
  46. l.connAccess.Unlock()
  47. return &loopBackDetectWrapper{Conn: conn, detector: l, connAddr: source}
  48. }
  49. }
  50. func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Socksaddr) N.NetPacketConn {
  51. source := M.AddrPortFromNet(conn.LocalAddr())
  52. if !source.IsValid() {
  53. return conn
  54. }
  55. if !source.Addr().IsLoopback() {
  56. _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr())
  57. if err != nil {
  58. return conn
  59. }
  60. }
  61. l.packetConnAccess.Lock()
  62. l.packetConnMap[source.Port()] = destination.AddrPort().Port()
  63. l.packetConnAccess.Unlock()
  64. return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: source.Port()}
  65. }
  66. func (l *loopBackDetector) CheckConn(source netip.AddrPort, local netip.AddrPort) bool {
  67. l.connAccess.RLock()
  68. defer l.connAccess.RUnlock()
  69. destination, loaded := l.connMap[source]
  70. return loaded && destination != local
  71. }
  72. func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.AddrPort) bool {
  73. if !source.IsValid() {
  74. return false
  75. }
  76. if !source.Addr().IsLoopback() {
  77. _, err := l.networkManager.InterfaceFinder().ByAddr(source.Addr())
  78. if err != nil {
  79. return false
  80. }
  81. }
  82. if N.IsPublicAddr(source.Addr()) {
  83. return false
  84. }
  85. l.packetConnAccess.RLock()
  86. defer l.packetConnAccess.RUnlock()
  87. destinationPort, loaded := l.packetConnMap[source.Port()]
  88. return loaded && destinationPort != local.Port()
  89. }
  90. type loopBackDetectWrapper struct {
  91. net.Conn
  92. detector *loopBackDetector
  93. connAddr netip.AddrPort
  94. closeOnce sync.Once
  95. }
  96. func (w *loopBackDetectWrapper) Close() error {
  97. w.closeOnce.Do(func() {
  98. w.detector.connAccess.Lock()
  99. delete(w.detector.connMap, w.connAddr)
  100. w.detector.connAccess.Unlock()
  101. })
  102. return w.Conn.Close()
  103. }
  104. func (w *loopBackDetectWrapper) ReaderReplaceable() bool {
  105. return true
  106. }
  107. func (w *loopBackDetectWrapper) WriterReplaceable() bool {
  108. return true
  109. }
  110. func (w *loopBackDetectWrapper) Upstream() any {
  111. return w.Conn
  112. }
  113. type loopBackDetectPacketWrapper struct {
  114. N.NetPacketConn
  115. detector *loopBackDetector
  116. connPort uint16
  117. closeOnce sync.Once
  118. }
  119. func (w *loopBackDetectPacketWrapper) Close() error {
  120. w.closeOnce.Do(func() {
  121. w.detector.packetConnAccess.Lock()
  122. delete(w.detector.packetConnMap, w.connPort)
  123. w.detector.packetConnAccess.Unlock()
  124. })
  125. return w.NetPacketConn.Close()
  126. }
  127. func (w *loopBackDetectPacketWrapper) ReaderReplaceable() bool {
  128. return true
  129. }
  130. func (w *loopBackDetectPacketWrapper) WriterReplaceable() bool {
  131. return true
  132. }
  133. func (w *loopBackDetectPacketWrapper) Upstream() any {
  134. return w.NetPacketConn
  135. }
  // abstractUDPConn matches connections that are usable both as a connected
// net.Conn and as a net.PacketConn. NewConn type-asserts against it to route
// such connections into the packet registry instead of the connection
// registry.
type abstractUDPConn interface {
	net.PacketConn
	net.Conn
}
  140. type loopBackDetectUDPWrapper struct {
  141. abstractUDPConn
  142. detector *loopBackDetector
  143. connPort uint16
  144. closeOnce sync.Once
  145. }
  146. func (w *loopBackDetectUDPWrapper) Close() error {
  147. w.closeOnce.Do(func() {
  148. w.detector.packetConnAccess.Lock()
  149. delete(w.detector.packetConnMap, w.connPort)
  150. w.detector.packetConnAccess.Unlock()
  151. })
  152. return w.abstractUDPConn.Close()
  153. }
  154. func (w *loopBackDetectUDPWrapper) ReaderReplaceable() bool {
  155. return true
  156. }
  157. func (w *loopBackDetectUDPWrapper) WriterReplaceable() bool {
  158. return true
  159. }
  160. func (w *loopBackDetectUDPWrapper) Upstream() any {
  161. return w.abstractUDPConn
  162. }