default_parallel_network.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing/common"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. M "github.com/sagernet/sing/common/metadata"
  11. N "github.com/sagernet/sing/common/network"
  12. )
  13. func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  14. if len(destinationAddresses) == 0 {
  15. if !destination.IsIP() {
  16. panic("invalid usage")
  17. }
  18. destinationAddresses = []netip.Addr{destination.Addr}
  19. }
  20. if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
  21. return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  22. }
  23. var errors []error
  24. if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
  25. for _, address := range destinationAddresses {
  26. conn, err := parallelDialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  27. if err == nil {
  28. return conn, nil
  29. }
  30. errors = append(errors, err)
  31. }
  32. } else {
  33. for _, address := range destinationAddresses {
  34. conn, err := dialer.DialContext(ctx, network, M.SocksaddrFrom(address, destination.Port))
  35. if err == nil {
  36. return conn, nil
  37. }
  38. errors = append(errors, err)
  39. }
  40. }
  41. return nil, E.Errors(errors...)
  42. }
  43. func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  44. if len(destinationAddresses) == 0 {
  45. if !destination.IsIP() {
  46. panic("invalid usage")
  47. }
  48. destinationAddresses = []netip.Addr{destination.Addr}
  49. }
  50. if fallbackDelay == 0 {
  51. fallbackDelay = N.DefaultFallbackDelay
  52. }
  53. returned := make(chan struct{})
  54. defer close(returned)
  55. addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  56. return address.Is4() || address.Is4In6()
  57. })
  58. addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  59. return address.Is6() && !address.Is4In6()
  60. })
  61. if len(addresses4) == 0 || len(addresses6) == 0 {
  62. return DialSerialNetwork(ctx, dialer, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  63. }
  64. var primaries, fallbacks []netip.Addr
  65. if preferIPv6 {
  66. primaries = addresses6
  67. fallbacks = addresses4
  68. } else {
  69. primaries = addresses4
  70. fallbacks = addresses6
  71. }
  72. type dialResult struct {
  73. net.Conn
  74. error
  75. primary bool
  76. done bool
  77. }
  78. results := make(chan dialResult) // unbuffered
  79. startRacer := func(ctx context.Context, primary bool) {
  80. ras := primaries
  81. if !primary {
  82. ras = fallbacks
  83. }
  84. c, err := DialSerialNetwork(ctx, dialer, network, destination, ras, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  85. select {
  86. case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
  87. case <-returned:
  88. if c != nil {
  89. c.Close()
  90. }
  91. }
  92. }
  93. var primary, fallback dialResult
  94. primaryCtx, primaryCancel := context.WithCancel(ctx)
  95. defer primaryCancel()
  96. go startRacer(primaryCtx, true)
  97. fallbackTimer := time.NewTimer(fallbackDelay)
  98. defer fallbackTimer.Stop()
  99. for {
  100. select {
  101. case <-fallbackTimer.C:
  102. fallbackCtx, fallbackCancel := context.WithCancel(ctx)
  103. defer fallbackCancel()
  104. go startRacer(fallbackCtx, false)
  105. case res := <-results:
  106. if res.error == nil {
  107. return res.Conn, nil
  108. }
  109. if res.primary {
  110. primary = res
  111. } else {
  112. fallback = res
  113. }
  114. if primary.done && fallback.done {
  115. return nil, primary.error
  116. }
  117. if res.primary && fallbackTimer.Stop() {
  118. fallbackTimer.Reset(0)
  119. }
  120. }
  121. }
  122. }
  123. func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
  124. if len(destinationAddresses) == 0 {
  125. if !destination.IsIP() {
  126. panic("invalid usage")
  127. }
  128. destinationAddresses = []netip.Addr{destination.Addr}
  129. }
  130. if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel {
  131. return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  132. }
  133. var errors []error
  134. if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
  135. for _, address := range destinationAddresses {
  136. conn, err := parallelDialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  137. if err == nil {
  138. return conn, address, nil
  139. }
  140. errors = append(errors, err)
  141. }
  142. } else {
  143. for _, address := range destinationAddresses {
  144. conn, err := dialer.ListenPacket(ctx, M.SocksaddrFrom(address, destination.Port))
  145. if err == nil {
  146. return conn, address, nil
  147. }
  148. errors = append(errors, err)
  149. }
  150. }
  151. return nil, netip.Addr{}, E.Errors(errors...)
  152. }