default_parallel_network.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing/common"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. M "github.com/sagernet/sing/common/metadata"
  11. N "github.com/sagernet/sing/common/network"
  12. )
  13. func DialSerialNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
  14. if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
  15. return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, fallbackDelay)
  16. }
  17. var errors []error
  18. for _, address := range destinationAddresses {
  19. conn, err := dialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, fallbackDelay)
  20. if err == nil {
  21. return conn, nil
  22. }
  23. errors = append(errors, err)
  24. }
  25. return nil, E.Errors(errors...)
  26. }
  27. func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) {
  28. if fallbackDelay == 0 {
  29. fallbackDelay = N.DefaultFallbackDelay
  30. }
  31. returned := make(chan struct{})
  32. defer close(returned)
  33. addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  34. return address.Is4() || address.Is4In6()
  35. })
  36. addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  37. return address.Is6() && !address.Is4In6()
  38. })
  39. if len(addresses4) == 0 || len(addresses6) == 0 {
  40. return DialSerialNetwork(ctx, dialer, network, destination, destinationAddresses, strategy, fallbackDelay)
  41. }
  42. var primaries, fallbacks []netip.Addr
  43. if preferIPv6 {
  44. primaries = addresses6
  45. fallbacks = addresses4
  46. } else {
  47. primaries = addresses4
  48. fallbacks = addresses6
  49. }
  50. type dialResult struct {
  51. net.Conn
  52. error
  53. primary bool
  54. done bool
  55. }
  56. results := make(chan dialResult) // unbuffered
  57. startRacer := func(ctx context.Context, primary bool) {
  58. ras := primaries
  59. if !primary {
  60. ras = fallbacks
  61. }
  62. c, err := DialSerialNetwork(ctx, dialer, network, destination, ras, strategy, fallbackDelay)
  63. select {
  64. case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
  65. case <-returned:
  66. if c != nil {
  67. c.Close()
  68. }
  69. }
  70. }
  71. var primary, fallback dialResult
  72. primaryCtx, primaryCancel := context.WithCancel(ctx)
  73. defer primaryCancel()
  74. go startRacer(primaryCtx, true)
  75. fallbackTimer := time.NewTimer(fallbackDelay)
  76. defer fallbackTimer.Stop()
  77. for {
  78. select {
  79. case <-fallbackTimer.C:
  80. fallbackCtx, fallbackCancel := context.WithCancel(ctx)
  81. defer fallbackCancel()
  82. go startRacer(fallbackCtx, false)
  83. case res := <-results:
  84. if res.error == nil {
  85. return res.Conn, nil
  86. }
  87. if res.primary {
  88. primary = res
  89. } else {
  90. fallback = res
  91. }
  92. if primary.done && fallback.done {
  93. return nil, primary.error
  94. }
  95. if res.primary && fallbackTimer.Stop() {
  96. fallbackTimer.Reset(0)
  97. }
  98. }
  99. }
  100. }
  101. func ListenSerialNetworkPacket(ctx context.Context, dialer ParallelInterfaceDialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
  102. if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
  103. return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, fallbackDelay)
  104. }
  105. var errors []error
  106. for _, address := range destinationAddresses {
  107. conn, err := dialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, fallbackDelay)
  108. if err == nil {
  109. return conn, address, nil
  110. }
  111. errors = append(errors, err)
  112. }
  113. return nil, netip.Addr{}, E.Errors(errors...)
  114. }