1
0

resolve.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package dialer
  2. import (
  3. "context"
  4. "net"
  5. "sync"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. C "github.com/sagernet/sing-box/constant"
  9. "github.com/sagernet/sing-box/log"
  10. "github.com/sagernet/sing/common/bufio"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. "github.com/sagernet/sing/service"
  15. )
  16. var (
  17. _ N.Dialer = (*resolveDialer)(nil)
  18. _ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil)
  19. )
  20. type ResolveDialer interface {
  21. N.Dialer
  22. QueryOptions() adapter.DNSQueryOptions
  23. }
  24. type ParallelInterfaceResolveDialer interface {
  25. ParallelInterfaceDialer
  26. QueryOptions() adapter.DNSQueryOptions
  27. }
  28. type resolveDialer struct {
  29. transport adapter.DNSTransportManager
  30. router adapter.DNSRouter
  31. dialer N.Dialer
  32. parallel bool
  33. server string
  34. initOnce sync.Once
  35. initErr error
  36. queryOptions adapter.DNSQueryOptions
  37. fallbackDelay time.Duration
  38. }
  39. func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
  40. if parallelDialer, isParallel := dialer.(ParallelInterfaceDialer); isParallel {
  41. return &resolveParallelNetworkDialer{
  42. resolveDialer{
  43. transport: service.FromContext[adapter.DNSTransportManager](ctx),
  44. router: service.FromContext[adapter.DNSRouter](ctx),
  45. dialer: dialer,
  46. parallel: parallel,
  47. server: server,
  48. queryOptions: queryOptions,
  49. fallbackDelay: fallbackDelay,
  50. },
  51. parallelDialer,
  52. }
  53. }
  54. return &resolveDialer{
  55. transport: service.FromContext[adapter.DNSTransportManager](ctx),
  56. router: service.FromContext[adapter.DNSRouter](ctx),
  57. dialer: dialer,
  58. parallel: parallel,
  59. server: server,
  60. queryOptions: queryOptions,
  61. fallbackDelay: fallbackDelay,
  62. }
  63. }
  64. type resolveParallelNetworkDialer struct {
  65. resolveDialer
  66. dialer ParallelInterfaceDialer
  67. }
  68. func (d *resolveDialer) initialize() error {
  69. d.initOnce.Do(d.initServer)
  70. return d.initErr
  71. }
  72. func (d *resolveDialer) initServer() {
  73. if d.server == "" {
  74. return
  75. }
  76. transport, loaded := d.transport.Transport(d.server)
  77. if !loaded {
  78. d.initErr = E.New("domain resolver not found: " + d.server)
  79. return
  80. }
  81. d.queryOptions.Transport = transport
  82. }
  83. func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  84. err := d.initialize()
  85. if err != nil {
  86. return nil, err
  87. }
  88. if !destination.IsFqdn() {
  89. return d.dialer.DialContext(ctx, network, destination)
  90. }
  91. ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
  92. addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
  93. if err != nil {
  94. return nil, err
  95. }
  96. if d.parallel {
  97. return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
  98. } else {
  99. return N.DialSerial(ctx, d.dialer, network, destination, addresses)
  100. }
  101. }
  102. func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  103. err := d.initialize()
  104. if err != nil {
  105. return nil, err
  106. }
  107. if !destination.IsFqdn() {
  108. return d.dialer.ListenPacket(ctx, destination)
  109. }
  110. ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
  111. addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
  112. if err != nil {
  113. return nil, err
  114. }
  115. conn, destinationAddress, err := N.ListenSerial(ctx, d.dialer, destination, addresses)
  116. if err != nil {
  117. return nil, err
  118. }
  119. return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
  120. }
  121. func (d *resolveDialer) QueryOptions() adapter.DNSQueryOptions {
  122. return d.queryOptions
  123. }
  124. func (d *resolveDialer) Upstream() any {
  125. return d.dialer
  126. }
  127. func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
  128. err := d.initialize()
  129. if err != nil {
  130. return nil, err
  131. }
  132. if !destination.IsFqdn() {
  133. return d.dialer.DialContext(ctx, network, destination)
  134. }
  135. ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
  136. addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
  137. if err != nil {
  138. return nil, err
  139. }
  140. if fallbackDelay == 0 {
  141. fallbackDelay = d.fallbackDelay
  142. }
  143. if d.parallel {
  144. return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.queryOptions.Strategy == C.DomainStrategyPreferIPv6, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  145. } else {
  146. return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  147. }
  148. }
  149. func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
  150. err := d.initialize()
  151. if err != nil {
  152. return nil, err
  153. }
  154. if !destination.IsFqdn() {
  155. return d.dialer.ListenPacket(ctx, destination)
  156. }
  157. ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
  158. addresses, err := d.router.Lookup(ctx, destination.Fqdn, d.queryOptions)
  159. if err != nil {
  160. return nil, err
  161. }
  162. if fallbackDelay == 0 {
  163. fallbackDelay = d.fallbackDelay
  164. }
  165. conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay)
  166. if err != nil {
  167. return nil, err
  168. }
  169. return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
  170. }
  171. func (d *resolveParallelNetworkDialer) QueryOptions() adapter.DNSQueryOptions {
  172. return d.queryOptions
  173. }
  174. func (d *resolveParallelNetworkDialer) Upstream() any {
  175. return d.dialer
  176. }