default.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/common/conntrack"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing-box/option"
  11. "github.com/sagernet/sing/common"
  12. "github.com/sagernet/sing/common/atomic"
  13. "github.com/sagernet/sing/common/control"
  14. E "github.com/sagernet/sing/common/exceptions"
  15. M "github.com/sagernet/sing/common/metadata"
  16. N "github.com/sagernet/sing/common/network"
  17. )
  18. var (
  19. _ ParallelInterfaceDialer = (*DefaultDialer)(nil)
  20. _ WireGuardListener = (*DefaultDialer)(nil)
  21. )
  22. type DefaultDialer struct {
  23. dialer4 tcpDialer
  24. dialer6 tcpDialer
  25. udpDialer4 net.Dialer
  26. udpDialer6 net.Dialer
  27. udpListener net.ListenConfig
  28. udpAddr4 string
  29. udpAddr6 string
  30. isWireGuardListener bool
  31. networkManager adapter.NetworkManager
  32. networkStrategy C.NetworkStrategy
  33. networkType []C.InterfaceType
  34. fallbackNetworkType []C.InterfaceType
  35. networkFallbackDelay time.Duration
  36. networkLastFallback atomic.TypedValue[time.Time]
  37. }
  38. func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) {
  39. var (
  40. dialer net.Dialer
  41. listener net.ListenConfig
  42. interfaceFinder control.InterfaceFinder
  43. networkStrategy C.NetworkStrategy
  44. networkType []C.InterfaceType
  45. fallbackNetworkType []C.InterfaceType
  46. networkFallbackDelay time.Duration
  47. )
  48. if networkManager != nil {
  49. interfaceFinder = networkManager.InterfaceFinder()
  50. } else {
  51. interfaceFinder = control.NewDefaultInterfaceFinder()
  52. }
  53. if options.BindInterface != "" {
  54. bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
  55. dialer.Control = control.Append(dialer.Control, bindFunc)
  56. listener.Control = control.Append(listener.Control, bindFunc)
  57. }
  58. if options.RoutingMark > 0 {
  59. dialer.Control = control.Append(dialer.Control, control.RoutingMark(uint32(options.RoutingMark)))
  60. listener.Control = control.Append(listener.Control, control.RoutingMark(uint32(options.RoutingMark)))
  61. }
  62. if networkManager != nil {
  63. autoRedirectOutputMark := networkManager.AutoRedirectOutputMark()
  64. if autoRedirectOutputMark > 0 {
  65. if options.RoutingMark > 0 {
  66. return nil, E.New("`routing_mark` is conflict with `tun.auto_redirect` with `tun.route_[_exclude]_address_set")
  67. }
  68. dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark))
  69. listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark))
  70. }
  71. }
  72. if C.NetworkStrategy(options.NetworkStrategy) != C.NetworkStrategyDefault {
  73. if options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil {
  74. return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address` and `inet6_bind_address`")
  75. }
  76. networkStrategy = C.NetworkStrategy(options.NetworkStrategy)
  77. networkType = common.Map(options.NetworkType, option.InterfaceType.Build)
  78. fallbackNetworkType = common.Map(options.FallbackNetworkType, option.InterfaceType.Build)
  79. networkFallbackDelay = time.Duration(options.NetworkFallbackDelay)
  80. if networkManager == nil || !networkManager.AutoDetectInterface() {
  81. return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
  82. }
  83. }
  84. if networkManager != nil && options.BindInterface == "" && options.Inet4BindAddress == nil && options.Inet6BindAddress == nil {
  85. defaultOptions := networkManager.DefaultOptions()
  86. if options.BindInterface == "" {
  87. if defaultOptions.BindInterface != "" {
  88. bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.BindInterface, -1)
  89. dialer.Control = control.Append(dialer.Control, bindFunc)
  90. listener.Control = control.Append(listener.Control, bindFunc)
  91. } else if networkManager.AutoDetectInterface() {
  92. if defaultOptions.NetworkStrategy != C.NetworkStrategyDefault && C.NetworkStrategy(options.NetworkStrategy) == C.NetworkStrategyDefault {
  93. networkStrategy = defaultOptions.NetworkStrategy
  94. networkType = defaultOptions.NetworkType
  95. fallbackNetworkType = defaultOptions.FallbackNetworkType
  96. networkFallbackDelay = defaultOptions.FallbackDelay
  97. bindFunc := networkManager.ProtectFunc()
  98. dialer.Control = control.Append(dialer.Control, bindFunc)
  99. listener.Control = control.Append(listener.Control, bindFunc)
  100. } else {
  101. bindFunc := networkManager.AutoDetectInterfaceFunc()
  102. dialer.Control = control.Append(dialer.Control, bindFunc)
  103. listener.Control = control.Append(listener.Control, bindFunc)
  104. }
  105. }
  106. }
  107. if options.RoutingMark == 0 && defaultOptions.RoutingMark != 0 {
  108. dialer.Control = control.Append(dialer.Control, control.RoutingMark(defaultOptions.RoutingMark))
  109. listener.Control = control.Append(listener.Control, control.RoutingMark(defaultOptions.RoutingMark))
  110. }
  111. }
  112. if options.ReuseAddr {
  113. listener.Control = control.Append(listener.Control, control.ReuseAddr())
  114. }
  115. if options.ProtectPath != "" {
  116. dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath))
  117. listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath))
  118. }
  119. if options.ConnectTimeout != 0 {
  120. dialer.Timeout = time.Duration(options.ConnectTimeout)
  121. } else {
  122. dialer.Timeout = C.TCPConnectTimeout
  123. }
  124. // TODO: Add an option to customize the keep alive period
  125. dialer.KeepAlive = C.TCPKeepAliveInitial
  126. dialer.Control = control.Append(dialer.Control, control.SetKeepAlivePeriod(C.TCPKeepAliveInitial, C.TCPKeepAliveInterval))
  127. var udpFragment bool
  128. if options.UDPFragment != nil {
  129. udpFragment = *options.UDPFragment
  130. } else {
  131. udpFragment = options.UDPFragmentDefault
  132. }
  133. if !udpFragment {
  134. dialer.Control = control.Append(dialer.Control, control.DisableUDPFragment())
  135. listener.Control = control.Append(listener.Control, control.DisableUDPFragment())
  136. }
  137. var (
  138. dialer4 = dialer
  139. udpDialer4 = dialer
  140. udpAddr4 string
  141. )
  142. if options.Inet4BindAddress != nil {
  143. bindAddr := options.Inet4BindAddress.Build(netip.IPv4Unspecified())
  144. dialer4.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
  145. udpDialer4.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
  146. udpAddr4 = M.SocksaddrFrom(bindAddr, 0).String()
  147. }
  148. var (
  149. dialer6 = dialer
  150. udpDialer6 = dialer
  151. udpAddr6 string
  152. )
  153. if options.Inet6BindAddress != nil {
  154. bindAddr := options.Inet6BindAddress.Build(netip.IPv6Unspecified())
  155. dialer6.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
  156. udpDialer6.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
  157. udpAddr6 = M.SocksaddrFrom(bindAddr, 0).String()
  158. }
  159. if options.TCPMultiPath {
  160. if !go121Available {
  161. return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.")
  162. }
  163. setMultiPathTCP(&dialer4)
  164. }
  165. if options.IsWireGuardListener {
  166. for _, controlFn := range WgControlFns {
  167. listener.Control = control.Append(listener.Control, controlFn)
  168. }
  169. }
  170. if networkStrategy != C.NetworkStrategyDefault && options.TCPFastOpen {
  171. return nil, E.New("`tcp_fast_open` is conflict with `network_strategy` or `route.default_network_strategy`")
  172. }
  173. tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
  174. if err != nil {
  175. return nil, err
  176. }
  177. tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen)
  178. if err != nil {
  179. return nil, err
  180. }
  181. return &DefaultDialer{
  182. dialer4: tcpDialer4,
  183. dialer6: tcpDialer6,
  184. udpDialer4: udpDialer4,
  185. udpDialer6: udpDialer6,
  186. udpListener: listener,
  187. udpAddr4: udpAddr4,
  188. udpAddr6: udpAddr6,
  189. isWireGuardListener: options.IsWireGuardListener,
  190. networkManager: networkManager,
  191. networkStrategy: networkStrategy,
  192. networkType: networkType,
  193. fallbackNetworkType: fallbackNetworkType,
  194. networkFallbackDelay: networkFallbackDelay,
  195. }, nil
  196. }
  197. func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
  198. if !address.IsValid() {
  199. return nil, E.New("invalid address")
  200. }
  201. if d.networkStrategy == C.NetworkStrategyDefault {
  202. switch N.NetworkName(network) {
  203. case N.NetworkUDP:
  204. if !address.IsIPv6() {
  205. return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
  206. } else {
  207. return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
  208. }
  209. }
  210. if !address.IsIPv6() {
  211. return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
  212. } else {
  213. return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
  214. }
  215. } else {
  216. return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
  217. }
  218. }
  219. func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  220. if strategy == C.NetworkStrategyDefault {
  221. return d.DialContext(ctx, network, address)
  222. }
  223. if !d.networkManager.AutoDetectInterface() {
  224. return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
  225. }
  226. var dialer net.Dialer
  227. if N.NetworkName(network) == N.NetworkTCP {
  228. dialer = dialerFromTCPDialer(d.dialer4)
  229. } else {
  230. dialer = d.udpDialer4
  231. }
  232. fastFallback := time.Now().Sub(d.networkLastFallback.Load()) < C.TCPTimeout
  233. var (
  234. conn net.Conn
  235. isPrimary bool
  236. err error
  237. )
  238. if !fastFallback {
  239. conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  240. } else {
  241. conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), strategy, interfaceType, fallbackInterfaceType, fallbackDelay, d.networkLastFallback.Store)
  242. }
  243. if err != nil {
  244. return nil, err
  245. }
  246. if !fastFallback && !isPrimary {
  247. d.networkLastFallback.Store(time.Now())
  248. }
  249. return trackConn(conn, nil)
  250. }
  251. func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  252. if d.networkStrategy == C.NetworkStrategyDefault {
  253. if destination.IsIPv6() {
  254. return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
  255. } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
  256. return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
  257. } else {
  258. return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
  259. }
  260. } else {
  261. return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
  262. }
  263. }
  264. func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
  265. if strategy == C.NetworkStrategyDefault {
  266. return d.ListenPacket(ctx, destination)
  267. }
  268. if !d.networkManager.AutoDetectInterface() {
  269. return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`")
  270. }
  271. network := N.NetworkUDP
  272. if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
  273. network += "4"
  274. }
  275. return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", strategy, interfaceType, fallbackInterfaceType, fallbackDelay))
  276. }
  277. func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
  278. return d.udpListener.ListenPacket(context.Background(), network, address)
  279. }
  280. func trackConn(conn net.Conn, err error) (net.Conn, error) {
  281. if !conntrack.Enabled || err != nil {
  282. return conn, err
  283. }
  284. return conntrack.NewConn(conn)
  285. }
  286. func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
  287. if !conntrack.Enabled || err != nil {
  288. return conn, err
  289. }
  290. return conntrack.NewPacketConn(conn)
  291. }