default.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/common/conntrack"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing-box/experimental/libbox/platform"
  11. "github.com/sagernet/sing-box/option"
  12. "github.com/sagernet/sing/common"
  13. "github.com/sagernet/sing/common/atomic"
  14. "github.com/sagernet/sing/common/control"
  15. E "github.com/sagernet/sing/common/exceptions"
  16. M "github.com/sagernet/sing/common/metadata"
  17. N "github.com/sagernet/sing/common/network"
  18. "github.com/sagernet/sing/service"
  19. )
  20. var (
  21. _ ParallelInterfaceDialer = (*DefaultDialer)(nil)
  22. _ WireGuardListener = (*DefaultDialer)(nil)
  23. )
  24. type DefaultDialer struct {
  25. dialer4 tcpDialer
  26. dialer6 tcpDialer
  27. udpDialer4 net.Dialer
  28. udpDialer6 net.Dialer
  29. udpListener net.ListenConfig
  30. udpAddr4 string
  31. udpAddr6 string
  32. isWireGuardListener bool
  33. networkManager adapter.NetworkManager
  34. networkStrategy *C.NetworkStrategy
  35. networkType []C.InterfaceType
  36. fallbackNetworkType []C.InterfaceType
  37. networkFallbackDelay time.Duration
  38. networkLastFallback atomic.TypedValue[time.Time]
  39. }
  40. func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) {
  41. networkManager := service.FromContext[adapter.NetworkManager](ctx)
  42. platformInterface := service.FromContext[platform.Interface](ctx)
  43. var (
  44. dialer net.Dialer
  45. listener net.ListenConfig
  46. interfaceFinder control.InterfaceFinder
  47. networkStrategy *C.NetworkStrategy
  48. networkType []C.InterfaceType
  49. fallbackNetworkType []C.InterfaceType
  50. networkFallbackDelay time.Duration
  51. )
  52. if networkManager != nil {
  53. interfaceFinder = networkManager.InterfaceFinder()
  54. } else {
  55. interfaceFinder = control.NewDefaultInterfaceFinder()
  56. }
  57. if options.BindInterface != "" {
  58. bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
  59. dialer.Control = control.Append(dialer.Control, bindFunc)
  60. listener.Control = control.Append(listener.Control, bindFunc)
  61. }
  62. if options.RoutingMark > 0 {
  63. dialer.Control = control.Append(dialer.Control, control.RoutingMark(uint32(options.RoutingMark)))
  64. listener.Control = control.Append(listener.Control, control.RoutingMark(uint32(options.RoutingMark)))
  65. }
  66. if networkManager != nil {
  67. autoRedirectOutputMark := networkManager.AutoRedirectOutputMark()
  68. if autoRedirectOutputMark > 0 {
  69. if options.RoutingMark > 0 {
  70. return nil, E.New("`routing_mark` is conflict with `tun.auto_redirect` with `tun.route_[_exclude]_address_set")
  71. }
  72. dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark))
  73. listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark))
  74. }
  75. }
  76. disableDefaultBind := options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil
  77. if disableDefaultBind || options.TCPFastOpen {
  78. if options.NetworkStrategy != nil || len(options.NetworkType) > 0 && options.FallbackNetworkType == nil && options.FallbackDelay == 0 {
  79. return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address`, `inet6_bind_address` and `tcp_fast_open`")
  80. }
  81. }
  82. if networkManager != nil {
  83. defaultOptions := networkManager.DefaultOptions()
  84. if !disableDefaultBind {
  85. if defaultOptions.BindInterface != "" {
  86. bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.BindInterface, -1)
  87. dialer.Control = control.Append(dialer.Control, bindFunc)
  88. listener.Control = control.Append(listener.Control, bindFunc)
  89. } else if networkManager.AutoDetectInterface() {
  90. if platformInterface != nil {
  91. networkStrategy = (*C.NetworkStrategy)(options.NetworkStrategy)
  92. if networkStrategy == nil {
  93. networkStrategy = common.Ptr(C.NetworkStrategyDefault)
  94. }
  95. networkType = common.Map(options.NetworkType, option.InterfaceType.Build)
  96. fallbackNetworkType = common.Map(options.FallbackNetworkType, option.InterfaceType.Build)
  97. if networkStrategy == nil && len(networkType) == 0 && len(fallbackNetworkType) == 0 {
  98. networkStrategy = defaultOptions.NetworkStrategy
  99. networkType = defaultOptions.NetworkType
  100. fallbackNetworkType = defaultOptions.FallbackNetworkType
  101. }
  102. networkFallbackDelay = time.Duration(options.FallbackDelay)
  103. if networkFallbackDelay == 0 && defaultOptions.FallbackDelay != 0 {
  104. networkFallbackDelay = defaultOptions.FallbackDelay
  105. }
  106. bindFunc := networkManager.ProtectFunc()
  107. dialer.Control = control.Append(dialer.Control, bindFunc)
  108. listener.Control = control.Append(listener.Control, bindFunc)
  109. } else {
  110. bindFunc := networkManager.AutoDetectInterfaceFunc()
  111. dialer.Control = control.Append(dialer.Control, bindFunc)
  112. listener.Control = control.Append(listener.Control, bindFunc)
  113. }
  114. }
  115. }
  116. if options.RoutingMark == 0 && defaultOptions.RoutingMark != 0 {
  117. dialer.Control = control.Append(dialer.Control, control.RoutingMark(defaultOptions.RoutingMark))
  118. listener.Control = control.Append(listener.Control, control.RoutingMark(defaultOptions.RoutingMark))
  119. }
  120. }
  121. if options.ReuseAddr {
  122. listener.Control = control.Append(listener.Control, control.ReuseAddr())
  123. }
  124. if options.ProtectPath != "" {
  125. dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath))
  126. listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath))
  127. }
  128. if options.ConnectTimeout != 0 {
  129. dialer.Timeout = time.Duration(options.ConnectTimeout)
  130. } else {
  131. dialer.Timeout = C.TCPConnectTimeout
  132. }
  133. // TODO: Add an option to customize the keep alive period
  134. dialer.KeepAlive = C.TCPKeepAliveInitial
  135. dialer.Control = control.Append(dialer.Control, control.SetKeepAlivePeriod(C.TCPKeepAliveInitial, C.TCPKeepAliveInterval))
  136. var udpFragment bool
  137. if options.UDPFragment != nil {
  138. udpFragment = *options.UDPFragment
  139. } else {
  140. udpFragment = options.UDPFragmentDefault
  141. }
  142. if !udpFragment {
  143. dialer.Control = control.Append(dialer.Control, control.DisableUDPFragment())
  144. listener.Control = control.Append(listener.Control, control.DisableUDPFragment())
  145. }
  146. var (
  147. dialer4 = dialer
  148. udpDialer4 = dialer
  149. udpAddr4 string
  150. )
  151. if options.Inet4BindAddress != nil {
  152. bindAddr := options.Inet4BindAddress.Build(netip.IPv4Unspecified())
  153. dialer4.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
  154. udpDialer4.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
  155. udpAddr4 = M.SocksaddrFrom(bindAddr, 0).String()
  156. }
  157. var (
  158. dialer6 = dialer
  159. udpDialer6 = dialer
  160. udpAddr6 string
  161. )
  162. if options.Inet6BindAddress != nil {
  163. bindAddr := options.Inet6BindAddress.Build(netip.IPv6Unspecified())
  164. dialer6.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
  165. udpDialer6.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
  166. udpAddr6 = M.SocksaddrFrom(bindAddr, 0).String()
  167. }
  168. if options.TCPMultiPath {
  169. if !go121Available {
  170. return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.")
  171. }
  172. setMultiPathTCP(&dialer4)
  173. }
  174. if options.IsWireGuardListener {
  175. for _, controlFn := range WgControlFns {
  176. listener.Control = control.Append(listener.Control, controlFn)
  177. }
  178. }
  179. tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
  180. if err != nil {
  181. return nil, err
  182. }
  183. tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen)
  184. if err != nil {
  185. return nil, err
  186. }
  187. return &DefaultDialer{
  188. dialer4: tcpDialer4,
  189. dialer6: tcpDialer6,
  190. udpDialer4: udpDialer4,
  191. udpDialer6: udpDialer6,
  192. udpListener: listener,
  193. udpAddr4: udpAddr4,
  194. udpAddr6: udpAddr6,
  195. isWireGuardListener: options.IsWireGuardListener,
  196. networkManager: networkManager,
  197. networkStrategy: networkStrategy,
  198. networkType: networkType,
  199. fallbackNetworkType: fallbackNetworkType,
  200. networkFallbackDelay: networkFallbackDelay,
  201. }, nil
  202. }
  203. func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
  204. if !address.IsValid() {
  205. return nil, E.New("invalid address")
  206. }
  207. if d.networkStrategy == nil {
  208. switch N.NetworkName(network) {
  209. case N.NetworkUDP:
  210. if !address.IsIPv6() {
  211. return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
  212. } else {
  213. return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
  214. }
  215. }
  216. if !address.IsIPv6() {
  217. return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
  218. } else {
  219. return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
  220. }
  221. } else {
  222. return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
  223. }
  224. }
  225. func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  226. if strategy == nil {
  227. strategy = d.networkStrategy
  228. }
  229. if strategy == nil {
  230. return d.DialContext(ctx, network, address)
  231. }
  232. if len(interfaceType) == 0 {
  233. interfaceType = d.networkType
  234. }
  235. if len(fallbackInterfaceType) == 0 {
  236. fallbackInterfaceType = d.fallbackNetworkType
  237. }
  238. if fallbackDelay == 0 {
  239. fallbackDelay = d.networkFallbackDelay
  240. }
  241. var dialer net.Dialer
  242. if N.NetworkName(network) == N.NetworkTCP {
  243. dialer = dialerFromTCPDialer(d.dialer4)
  244. } else {
  245. dialer = d.udpDialer4
  246. }
  247. fastFallback := time.Now().Sub(d.networkLastFallback.Load()) < C.TCPTimeout
  248. var (
  249. conn net.Conn
  250. isPrimary bool
  251. err error
  252. )
  253. if !fastFallback {
  254. conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), *strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  255. } else {
  256. conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), *strategy, interfaceType, fallbackInterfaceType, fallbackDelay, d.networkLastFallback.Store)
  257. }
  258. if err != nil {
  259. return nil, err
  260. }
  261. if !fastFallback && !isPrimary {
  262. d.networkLastFallback.Store(time.Now())
  263. }
  264. return trackConn(conn, nil)
  265. }
  266. func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  267. if d.networkStrategy == nil {
  268. if destination.IsIPv6() {
  269. return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
  270. } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
  271. return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
  272. } else {
  273. return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
  274. }
  275. } else {
  276. return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)
  277. }
  278. }
  279. func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
  280. if strategy == nil {
  281. strategy = d.networkStrategy
  282. }
  283. if strategy == nil {
  284. return d.ListenPacket(ctx, destination)
  285. }
  286. if len(interfaceType) == 0 {
  287. interfaceType = d.networkType
  288. }
  289. if len(fallbackInterfaceType) == 0 {
  290. fallbackInterfaceType = d.fallbackNetworkType
  291. }
  292. if fallbackDelay == 0 {
  293. fallbackDelay = d.networkFallbackDelay
  294. }
  295. network := N.NetworkUDP
  296. if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
  297. network += "4"
  298. }
  299. return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", *strategy, interfaceType, fallbackInterfaceType, fallbackDelay))
  300. }
  301. func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
  302. return d.udpListener.ListenPacket(context.Background(), network, address)
  303. }
  304. func trackConn(conn net.Conn, err error) (net.Conn, error) {
  305. if !conntrack.Enabled || err != nil {
  306. return conn, err
  307. }
  308. return conntrack.NewConn(conn)
  309. }
  310. func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
  311. if !conntrack.Enabled || err != nil {
  312. return conn, err
  313. }
  314. return conntrack.NewPacketConn(conn)
  315. }