transport.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. //go:build linux
  2. package resolved
  3. import (
  4. "context"
  5. "net/netip"
  6. "os"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/sagernet/sing-box/adapter"
  12. "github.com/sagernet/sing-box/common/dialer"
  13. "github.com/sagernet/sing-box/common/tls"
  14. C "github.com/sagernet/sing-box/constant"
  15. "github.com/sagernet/sing-box/dns"
  16. "github.com/sagernet/sing-box/dns/transport"
  17. "github.com/sagernet/sing-box/log"
  18. "github.com/sagernet/sing-box/option"
  19. "github.com/sagernet/sing/common"
  20. E "github.com/sagernet/sing/common/exceptions"
  21. "github.com/sagernet/sing/common/logger"
  22. M "github.com/sagernet/sing/common/metadata"
  23. "github.com/sagernet/sing/service"
  24. mDNS "github.com/miekg/dns"
  25. )
// RegisterTransport registers NewTransport in registry as the constructor for
// DNS servers of type C.TypeResolved, whose options are decoded into
// option.ResolvedDNSServerOptions.
func RegisterTransport(registry *dns.TransportRegistry) {
	dns.RegisterTransport[option.ResolvedDNSServerOptions](registry, C.TypeResolved, NewTransport)
}
// Compile-time check that *Transport implements adapter.DNSTransport.
var _ adapter.DNSTransport = (*Transport)(nil)

// Transport is a DNS transport that forwards each query to the upstream
// servers of one link tracked by a resolved Service. The link is chosen from
// the links' domains or the service's default-route sequence (see Exchange).
type Transport struct {
	dns.TransportAdapter
	ctx    context.Context
	logger logger.ContextLogger
	// serviceTag is the tag of the resolved service looked up in Start.
	serviceTag string
	// acceptDefaultResolvers allows routing-only "." domains to match and
	// enables the default-route fallback in Exchange.
	acceptDefaultResolvers bool
	// ndots, timeout, attempts and rotate drive search-list expansion and the
	// per-name retry loop; NewTransport currently sets fixed values.
	ndots    int
	timeout  time.Duration
	attempts int
	rotate   bool
	// service is the resolved service bound in Start.
	service *Service
	// linkAccess guards linkServers.
	linkAccess sync.RWMutex
	// linkServers holds the upstream transports built for each link,
	// maintained by updateTransports and deleteTransport.
	linkServers map[*TransportLink]*LinkServers
}
  44. type LinkServers struct {
  45. Link *TransportLink
  46. Servers []adapter.DNSTransport
  47. serverOffset uint32
  48. }
  49. func (c *LinkServers) ServerOffset(rotate bool) uint32 {
  50. if rotate {
  51. return atomic.AddUint32(&c.serverOffset, 1) - 1
  52. }
  53. return 0
  54. }
  55. func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.ResolvedDNSServerOptions) (adapter.DNSTransport, error) {
  56. return &Transport{
  57. TransportAdapter: dns.NewTransportAdapter(C.DNSTypeDHCP, tag, nil),
  58. ctx: ctx,
  59. logger: logger,
  60. serviceTag: options.Service,
  61. acceptDefaultResolvers: options.AcceptDefaultResolvers,
  62. // ndots: options.NDots,
  63. // timeout: time.Duration(options.Timeout),
  64. // attempts: options.Attempts,
  65. // rotate: options.Rotate,
  66. ndots: 1,
  67. timeout: 5 * time.Second,
  68. attempts: 2,
  69. linkServers: make(map[*TransportLink]*LinkServers),
  70. }, nil
  71. }
  72. func (t *Transport) Start(stage adapter.StartStage) error {
  73. if stage != adapter.StartStateInitialize {
  74. return nil
  75. }
  76. serviceManager := service.FromContext[adapter.ServiceManager](t.ctx)
  77. service, loaded := serviceManager.Get(t.serviceTag)
  78. if !loaded {
  79. return E.New("service not found: ", t.serviceTag)
  80. }
  81. resolvedInbound, isResolved := service.(*Service)
  82. if !isResolved {
  83. return E.New("service is not resolved: ", t.serviceTag)
  84. }
  85. resolvedInbound.updateCallback = t.updateTransports
  86. resolvedInbound.deleteCallback = t.deleteTransport
  87. t.service = resolvedInbound
  88. return nil
  89. }
  90. func (t *Transport) Close() error {
  91. t.linkAccess.RLock()
  92. defer t.linkAccess.RUnlock()
  93. for _, servers := range t.linkServers {
  94. for _, server := range servers.Servers {
  95. server.Close()
  96. }
  97. }
  98. return nil
  99. }
  100. func (t *Transport) updateTransports(link *TransportLink) error {
  101. t.linkAccess.Lock()
  102. defer t.linkAccess.Unlock()
  103. if servers, loaded := t.linkServers[link]; loaded {
  104. for _, server := range servers.Servers {
  105. server.Close()
  106. }
  107. }
  108. serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{
  109. BindInterface: link.iif.Name,
  110. UDPFragmentDefault: true,
  111. }))
  112. var transports []adapter.DNSTransport
  113. for _, address := range link.address {
  114. serverAddr, ok := netip.AddrFromSlice(address.Address)
  115. if !ok {
  116. return os.ErrInvalid
  117. }
  118. if link.dnsOverTLS {
  119. tlsConfig := common.Must1(tls.NewClient(t.ctx, t.logger, serverAddr.String(), option.OutboundTLSOptions{
  120. Enabled: true,
  121. ServerName: serverAddr.String(),
  122. }))
  123. transports = append(transports, transport.NewTLSRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, 53), tlsConfig))
  124. } else {
  125. transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, 53)))
  126. }
  127. }
  128. for _, address := range link.addressEx {
  129. serverAddr, ok := netip.AddrFromSlice(address.Address)
  130. if !ok {
  131. return os.ErrInvalid
  132. }
  133. if link.dnsOverTLS {
  134. var serverName string
  135. if address.Name != "" {
  136. serverName = address.Name
  137. } else {
  138. serverName = serverAddr.String()
  139. }
  140. tlsConfig := common.Must1(tls.NewClient(t.ctx, t.logger, serverAddr.String(), option.OutboundTLSOptions{
  141. Enabled: true,
  142. ServerName: serverName,
  143. }))
  144. transports = append(transports, transport.NewTLSRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, address.Port), tlsConfig))
  145. } else {
  146. transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, address.Port)))
  147. }
  148. }
  149. t.linkServers[link] = &LinkServers{
  150. Link: link,
  151. Servers: transports,
  152. }
  153. return nil
  154. }
  155. func (t *Transport) deleteTransport(link *TransportLink) {
  156. t.linkAccess.Lock()
  157. defer t.linkAccess.Unlock()
  158. servers, loaded := t.linkServers[link]
  159. if !loaded {
  160. return
  161. }
  162. for _, server := range servers.Servers {
  163. server.Close()
  164. }
  165. delete(t.linkServers, link)
  166. }
  167. func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  168. question := message.Question[0]
  169. var selectedLink *TransportLink
  170. t.service.linkAccess.RLock()
  171. for _, link := range t.service.links {
  172. for _, domain := range link.domain {
  173. if domain.Domain == "." && domain.RoutingOnly && !t.acceptDefaultResolvers {
  174. continue
  175. }
  176. if strings.HasSuffix(question.Name, domain.Domain) {
  177. selectedLink = link
  178. }
  179. }
  180. }
  181. if selectedLink == nil && t.acceptDefaultResolvers {
  182. for l := len(t.service.defaultRouteSequence); l > 0; l-- {
  183. selectedLink = t.service.links[t.service.defaultRouteSequence[l-1]]
  184. if len(selectedLink.address) > 0 || len(selectedLink.addressEx) > 0 {
  185. break
  186. }
  187. }
  188. }
  189. t.service.linkAccess.RUnlock()
  190. if selectedLink == nil {
  191. return dns.FixedResponseStatus(message, mDNS.RcodeNameError), nil
  192. }
  193. t.linkAccess.RLock()
  194. servers := t.linkServers[selectedLink]
  195. t.linkAccess.RUnlock()
  196. if len(servers.Servers) == 0 {
  197. return dns.FixedResponseStatus(message, mDNS.RcodeNameError), nil
  198. }
  199. if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
  200. return t.exchangeParallel(ctx, servers, message)
  201. } else {
  202. return t.exchangeSingleRequest(ctx, servers, message)
  203. }
  204. }
  205. func (t *Transport) exchangeSingleRequest(ctx context.Context, servers *LinkServers, message *mDNS.Msg) (*mDNS.Msg, error) {
  206. var lastErr error
  207. for _, fqdn := range servers.Link.nameList(t.ndots, message.Question[0].Name) {
  208. response, err := t.tryOneName(ctx, servers, message, fqdn)
  209. if err != nil {
  210. lastErr = err
  211. continue
  212. }
  213. return response, nil
  214. }
  215. return nil, lastErr
  216. }
  217. func (t *Transport) tryOneName(ctx context.Context, servers *LinkServers, message *mDNS.Msg, fqdn string) (*mDNS.Msg, error) {
  218. serverOffset := servers.ServerOffset(t.rotate)
  219. sLen := uint32(len(servers.Servers))
  220. var lastErr error
  221. for i := 0; i < t.attempts; i++ {
  222. for j := uint32(0); j < sLen; j++ {
  223. server := servers.Servers[(serverOffset+j)%sLen]
  224. question := message.Question[0]
  225. question.Name = fqdn
  226. exchangeMessage := *message
  227. exchangeMessage.Question = []mDNS.Question{question}
  228. exchangeCtx, cancel := context.WithTimeout(ctx, t.timeout)
  229. response, err := server.Exchange(exchangeCtx, &exchangeMessage)
  230. cancel()
  231. if err != nil {
  232. lastErr = err
  233. continue
  234. }
  235. return response, nil
  236. }
  237. }
  238. return nil, E.Cause(lastErr, fqdn)
  239. }
  240. func (t *Transport) exchangeParallel(ctx context.Context, servers *LinkServers, message *mDNS.Msg) (*mDNS.Msg, error) {
  241. returned := make(chan struct{})
  242. defer close(returned)
  243. type queryResult struct {
  244. response *mDNS.Msg
  245. err error
  246. }
  247. results := make(chan queryResult)
  248. startRacer := func(ctx context.Context, fqdn string) {
  249. response, err := t.tryOneName(ctx, servers, message, fqdn)
  250. select {
  251. case results <- queryResult{response, err}:
  252. case <-returned:
  253. }
  254. }
  255. queryCtx, queryCancel := context.WithCancel(ctx)
  256. defer queryCancel()
  257. var nameCount int
  258. for _, fqdn := range servers.Link.nameList(t.ndots, message.Question[0].Name) {
  259. nameCount++
  260. go startRacer(queryCtx, fqdn)
  261. }
  262. var errors []error
  263. for {
  264. select {
  265. case <-ctx.Done():
  266. return nil, ctx.Err()
  267. case result := <-results:
  268. if result.err == nil {
  269. return result.response, nil
  270. }
  271. errors = append(errors, result.err)
  272. if len(errors) == nameCount {
  273. return nil, E.Errors(errors...)
  274. }
  275. }
  276. }
  277. }