1
0

default_parallel_network.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing/common"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. M "github.com/sagernet/sing/common/metadata"
  11. N "github.com/sagernet/sing/common/network"
  12. )
  13. func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  14. if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
  15. return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  16. }
  17. var errors []error
  18. if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
  19. for _, address := range destinationAddresses {
  20. conn, err := parallelDialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  21. if err == nil {
  22. return conn, nil
  23. }
  24. errors = append(errors, err)
  25. }
  26. } else {
  27. for _, address := range destinationAddresses {
  28. conn, err := dialer.DialContext(ctx, network, M.SocksaddrFrom(address, destination.Port))
  29. if err == nil {
  30. return conn, nil
  31. }
  32. errors = append(errors, err)
  33. }
  34. }
  35. return nil, E.Errors(errors...)
  36. }
  37. func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  38. if fallbackDelay == 0 {
  39. fallbackDelay = N.DefaultFallbackDelay
  40. }
  41. returned := make(chan struct{})
  42. defer close(returned)
  43. addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  44. return address.Is4() || address.Is4In6()
  45. })
  46. addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  47. return address.Is6() && !address.Is4In6()
  48. })
  49. if len(addresses4) == 0 || len(addresses6) == 0 {
  50. return DialSerialNetwork(ctx, dialer, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  51. }
  52. var primaries, fallbacks []netip.Addr
  53. if preferIPv6 {
  54. primaries = addresses6
  55. fallbacks = addresses4
  56. } else {
  57. primaries = addresses4
  58. fallbacks = addresses6
  59. }
  60. type dialResult struct {
  61. net.Conn
  62. error
  63. primary bool
  64. done bool
  65. }
  66. results := make(chan dialResult) // unbuffered
  67. startRacer := func(ctx context.Context, primary bool) {
  68. ras := primaries
  69. if !primary {
  70. ras = fallbacks
  71. }
  72. c, err := DialSerialNetwork(ctx, dialer, network, destination, ras, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  73. select {
  74. case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
  75. case <-returned:
  76. if c != nil {
  77. c.Close()
  78. }
  79. }
  80. }
  81. var primary, fallback dialResult
  82. primaryCtx, primaryCancel := context.WithCancel(ctx)
  83. defer primaryCancel()
  84. go startRacer(primaryCtx, true)
  85. fallbackTimer := time.NewTimer(fallbackDelay)
  86. defer fallbackTimer.Stop()
  87. for {
  88. select {
  89. case <-fallbackTimer.C:
  90. fallbackCtx, fallbackCancel := context.WithCancel(ctx)
  91. defer fallbackCancel()
  92. go startRacer(fallbackCtx, false)
  93. case res := <-results:
  94. if res.error == nil {
  95. return res.Conn, nil
  96. }
  97. if res.primary {
  98. primary = res
  99. } else {
  100. fallback = res
  101. }
  102. if primary.done && fallback.done {
  103. return nil, primary.error
  104. }
  105. if res.primary && fallbackTimer.Stop() {
  106. fallbackTimer.Reset(0)
  107. }
  108. }
  109. }
  110. }
  111. func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
  112. if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
  113. return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  114. }
  115. var errors []error
  116. if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
  117. for _, address := range destinationAddresses {
  118. conn, err := parallelDialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  119. if err == nil {
  120. return conn, address, nil
  121. }
  122. errors = append(errors, err)
  123. }
  124. } else {
  125. for _, address := range destinationAddresses {
  126. conn, err := dialer.ListenPacket(ctx, M.SocksaddrFrom(address, destination.Port))
  127. if err == nil {
  128. return conn, address, nil
  129. }
  130. errors = append(errors, err)
  131. }
  132. }
  133. return nil, netip.Addr{}, E.Errors(errors...)
  134. }