default_parallel_interface.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "time"
  6. "github.com/sagernet/sing-box/adapter"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/control"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. N "github.com/sagernet/sing/common/network"
  12. )
  13. func (d *DefaultDialer) dialParallelInterface(ctx context.Context, dialer net.Dialer, network string, addr string, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, bool, error) {
  14. primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy, interfaceType, fallbackInterfaceType)
  15. if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
  16. return nil, false, E.New("no available network interface")
  17. }
  18. defaultInterface := d.networkManager.InterfaceMonitor().DefaultInterface()
  19. if fallbackDelay == 0 {
  20. fallbackDelay = N.DefaultFallbackDelay
  21. }
  22. returned := make(chan struct{})
  23. defer close(returned)
  24. type dialResult struct {
  25. net.Conn
  26. error
  27. primary bool
  28. }
  29. results := make(chan dialResult) // unbuffered
  30. startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) {
  31. perNetDialer := dialer
  32. if defaultInterface == nil || iif.Index != defaultInterface.Index {
  33. perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
  34. }
  35. conn, err := perNetDialer.DialContext(ctx, network, addr)
  36. if err != nil {
  37. select {
  38. case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Index, ")"), primary: primary}:
  39. case <-returned:
  40. }
  41. } else {
  42. select {
  43. case results <- dialResult{Conn: conn, primary: primary}:
  44. case <-returned:
  45. conn.Close()
  46. }
  47. }
  48. }
  49. primaryCtx, primaryCancel := context.WithCancel(ctx)
  50. defer primaryCancel()
  51. for _, iif := range primaryInterfaces {
  52. go startRacer(primaryCtx, true, iif)
  53. }
  54. var (
  55. fallbackTimer *time.Timer
  56. fallbackChan <-chan time.Time
  57. )
  58. if len(fallbackInterfaces) > 0 {
  59. fallbackTimer = time.NewTimer(fallbackDelay)
  60. defer fallbackTimer.Stop()
  61. fallbackChan = fallbackTimer.C
  62. }
  63. var errors []error
  64. for {
  65. select {
  66. case <-fallbackChan:
  67. fallbackCtx, fallbackCancel := context.WithCancel(ctx)
  68. defer fallbackCancel()
  69. for _, iif := range fallbackInterfaces {
  70. go startRacer(fallbackCtx, false, iif)
  71. }
  72. case res := <-results:
  73. if res.error == nil {
  74. return res.Conn, res.primary, nil
  75. }
  76. errors = append(errors, res.error)
  77. if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) {
  78. return nil, false, E.Errors(errors...)
  79. }
  80. if res.primary && fallbackTimer != nil && fallbackTimer.Stop() {
  81. fallbackTimer.Reset(0)
  82. }
  83. }
  84. }
  85. }
  86. func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, dialer net.Dialer, network string, addr string, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration, resetFastFallback func(time.Time)) (net.Conn, bool, error) {
  87. primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy, interfaceType, fallbackInterfaceType)
  88. if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
  89. return nil, false, E.New("no available network interface")
  90. }
  91. defaultInterface := d.networkManager.InterfaceMonitor().DefaultInterface()
  92. if fallbackDelay == 0 {
  93. fallbackDelay = N.DefaultFallbackDelay
  94. }
  95. returned := make(chan struct{})
  96. defer close(returned)
  97. type dialResult struct {
  98. net.Conn
  99. error
  100. primary bool
  101. }
  102. startAt := time.Now()
  103. results := make(chan dialResult) // unbuffered
  104. startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) {
  105. perNetDialer := dialer
  106. if defaultInterface == nil || iif.Index != defaultInterface.Index {
  107. perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index))
  108. }
  109. conn, err := perNetDialer.DialContext(ctx, network, addr)
  110. if err != nil {
  111. select {
  112. case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Index, ")"), primary: primary}:
  113. case <-returned:
  114. }
  115. } else {
  116. select {
  117. case results <- dialResult{Conn: conn, primary: primary}:
  118. case <-returned:
  119. if primary && time.Since(startAt) <= fallbackDelay {
  120. resetFastFallback(time.Time{})
  121. }
  122. conn.Close()
  123. }
  124. }
  125. }
  126. for _, iif := range primaryInterfaces {
  127. go startRacer(ctx, true, iif)
  128. }
  129. fallbackCtx, fallbackCancel := context.WithCancel(ctx)
  130. defer fallbackCancel()
  131. for _, iif := range fallbackInterfaces {
  132. go startRacer(fallbackCtx, false, iif)
  133. }
  134. var errors []error
  135. for {
  136. select {
  137. case res := <-results:
  138. if res.error == nil {
  139. return res.Conn, res.primary, nil
  140. }
  141. errors = append(errors, res.error)
  142. if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) {
  143. return nil, false, E.Errors(errors...)
  144. }
  145. }
  146. }
  147. }
  148. func (d *DefaultDialer) listenSerialInterfacePacket(ctx context.Context, listener net.ListenConfig, network string, addr string, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
  149. primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy, interfaceType, fallbackInterfaceType)
  150. if len(primaryInterfaces)+len(fallbackInterfaces) == 0 {
  151. return nil, E.New("no available network interface")
  152. }
  153. defaultInterface := d.networkManager.InterfaceMonitor().DefaultInterface()
  154. var errors []error
  155. for _, primaryInterface := range primaryInterfaces {
  156. perNetListener := listener
  157. if defaultInterface == nil || primaryInterface.Index != defaultInterface.Index {
  158. perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, primaryInterface.Name, primaryInterface.Index))
  159. }
  160. conn, err := perNetListener.ListenPacket(ctx, network, addr)
  161. if err == nil {
  162. return conn, nil
  163. }
  164. errors = append(errors, E.Cause(err, "listen ", primaryInterface.Name, " (", primaryInterface.Index, ")"))
  165. }
  166. for _, fallbackInterface := range fallbackInterfaces {
  167. perNetListener := listener
  168. if defaultInterface == nil || fallbackInterface.Index != defaultInterface.Index {
  169. perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, fallbackInterface.Name, fallbackInterface.Index))
  170. }
  171. conn, err := perNetListener.ListenPacket(ctx, network, addr)
  172. if err == nil {
  173. return conn, nil
  174. }
  175. errors = append(errors, E.Cause(err, "listen ", fallbackInterface.Name, " (", fallbackInterface.Index, ")"))
  176. }
  177. return nil, E.Errors(errors...)
  178. }
  179. func selectInterfaces(networkManager adapter.NetworkManager, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType) (primaryInterfaces []adapter.NetworkInterface, fallbackInterfaces []adapter.NetworkInterface) {
  180. interfaces := networkManager.NetworkInterfaces()
  181. switch strategy {
  182. case C.NetworkStrategyDefault:
  183. if len(interfaceType) == 0 {
  184. defaultIf := networkManager.InterfaceMonitor().DefaultInterface()
  185. if defaultIf != nil {
  186. for _, iif := range interfaces {
  187. if iif.Index == defaultIf.Index {
  188. primaryInterfaces = append(primaryInterfaces, iif)
  189. }
  190. }
  191. } else {
  192. primaryInterfaces = interfaces
  193. }
  194. } else {
  195. primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool {
  196. return common.Contains(interfaceType, it.Type)
  197. })
  198. }
  199. case C.NetworkStrategyHybrid:
  200. if len(interfaceType) == 0 {
  201. primaryInterfaces = interfaces
  202. } else {
  203. primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool {
  204. return common.Contains(interfaceType, it.Type)
  205. })
  206. }
  207. case C.NetworkStrategyFallback:
  208. if len(interfaceType) == 0 {
  209. defaultIf := networkManager.InterfaceMonitor().DefaultInterface()
  210. if defaultIf != nil {
  211. for _, iif := range interfaces {
  212. if iif.Index == defaultIf.Index {
  213. primaryInterfaces = append(primaryInterfaces, iif)
  214. break
  215. }
  216. }
  217. } else {
  218. primaryInterfaces = interfaces
  219. }
  220. } else {
  221. primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool {
  222. return common.Contains(interfaceType, it.Type)
  223. })
  224. }
  225. if len(fallbackInterfaceType) == 0 {
  226. fallbackInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool {
  227. return !common.Any(primaryInterfaces, func(iif adapter.NetworkInterface) bool {
  228. return it.Index == iif.Index
  229. })
  230. })
  231. } else {
  232. fallbackInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool {
  233. return common.Contains(fallbackInterfaceType, iif.Type)
  234. })
  235. }
  236. }
  237. return primaryInterfaces, fallbackInterfaces
  238. }