parallel.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. "github.com/sagernet/sing/common"
  8. M "github.com/sagernet/sing/common/metadata"
  9. N "github.com/sagernet/sing/common/network"
  10. C "github.com/sagernet/sing-box/constant"
  11. )
  12. func DialParallel(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.DomainStrategy, fallbackDelay time.Duration) (net.Conn, error) {
  13. // kanged form net.Dial
  14. returned := make(chan struct{})
  15. defer close(returned)
  16. addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  17. return address.Is4() || address.Is4In6()
  18. })
  19. addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
  20. return address.Is6() && !address.Is4In6()
  21. })
  22. if len(addresses4) == 0 || len(addresses6) == 0 {
  23. return DialSerial(ctx, dialer, network, destination, destinationAddresses)
  24. }
  25. var primaries, fallbacks []netip.Addr
  26. switch strategy {
  27. case C.DomainStrategyPreferIPv6:
  28. primaries = addresses6
  29. fallbacks = addresses4
  30. default:
  31. primaries = addresses4
  32. fallbacks = addresses6
  33. }
  34. type dialResult struct {
  35. net.Conn
  36. error
  37. primary bool
  38. done bool
  39. }
  40. results := make(chan dialResult) // unbuffered
  41. startRacer := func(ctx context.Context, primary bool) {
  42. ras := primaries
  43. if !primary {
  44. ras = fallbacks
  45. }
  46. c, err := DialSerial(ctx, dialer, network, destination, ras)
  47. select {
  48. case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
  49. case <-returned:
  50. if c != nil {
  51. c.Close()
  52. }
  53. }
  54. }
  55. var primary, fallback dialResult
  56. primaryCtx, primaryCancel := context.WithCancel(ctx)
  57. defer primaryCancel()
  58. go startRacer(primaryCtx, true)
  59. fallbackTimer := time.NewTimer(fallbackDelay)
  60. defer fallbackTimer.Stop()
  61. for {
  62. select {
  63. case <-fallbackTimer.C:
  64. fallbackCtx, fallbackCancel := context.WithCancel(ctx)
  65. defer fallbackCancel()
  66. go startRacer(fallbackCtx, false)
  67. case res := <-results:
  68. if res.error == nil {
  69. return res.Conn, nil
  70. }
  71. if res.primary {
  72. primary = res
  73. } else {
  74. fallback = res
  75. }
  76. if primary.done && fallback.done {
  77. return nil, primary.error
  78. }
  79. if res.primary && fallbackTimer.Stop() {
  80. fallbackTimer.Reset(0)
  81. }
  82. }
  83. }
  84. }