dns_transport.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. package tailscale
  2. import (
  3. "context"
  4. "net"
  5. "net/http"
  6. "net/netip"
  7. "net/url"
  8. "os"
  9. "reflect"
  10. "strings"
  11. "sync"
  12. "github.com/sagernet/sing-box/adapter"
  13. "github.com/sagernet/sing-box/common/dialer"
  14. "github.com/sagernet/sing-box/common/tls"
  15. C "github.com/sagernet/sing-box/constant"
  16. "github.com/sagernet/sing-box/dns"
  17. "github.com/sagernet/sing-box/dns/transport"
  18. "github.com/sagernet/sing-box/log"
  19. "github.com/sagernet/sing-box/option"
  20. "github.com/sagernet/sing/common"
  21. E "github.com/sagernet/sing/common/exceptions"
  22. "github.com/sagernet/sing/common/logger"
  23. M "github.com/sagernet/sing/common/metadata"
  24. N "github.com/sagernet/sing/common/network"
  25. "github.com/sagernet/sing/service"
  26. nDNS "github.com/sagernet/tailscale/net/dns"
  27. "github.com/sagernet/tailscale/types/dnstype"
  28. "github.com/sagernet/tailscale/wgengine/router"
  29. "github.com/sagernet/tailscale/wgengine/wgcfg"
  30. mDNS "github.com/miekg/dns"
  31. "go4.org/netipx"
  32. "golang.org/x/net/http2"
  33. )
  34. func RegistryTransport(registry *dns.TransportRegistry) {
  35. dns.RegisterTransport[option.TailscaleDNSServerOptions](registry, C.DNSTypeTailscale, NewDNSTransport)
  36. }
  37. type DNSTransport struct {
  38. dns.TransportAdapter
  39. ctx context.Context
  40. logger logger.ContextLogger
  41. endpointTag string
  42. acceptDefaultResolvers bool
  43. dnsRouter adapter.DNSRouter
  44. endpointManager adapter.EndpointManager
  45. cfg *wgcfg.Config
  46. dnsCfg *nDNS.Config
  47. endpoint *Endpoint
  48. routePrefixes []netip.Prefix
  49. routes map[string][]adapter.DNSTransport
  50. hosts map[string][]netip.Addr
  51. defaultResolvers []adapter.DNSTransport
  52. }
  53. func NewDNSTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.TailscaleDNSServerOptions) (adapter.DNSTransport, error) {
  54. if options.Endpoint == "" {
  55. return nil, E.New("missing tailscale endpoint tag")
  56. }
  57. return &DNSTransport{
  58. TransportAdapter: dns.NewTransportAdapter(C.DNSTypeTailscale, tag, nil),
  59. ctx: ctx,
  60. logger: logger,
  61. endpointTag: options.Endpoint,
  62. acceptDefaultResolvers: options.AcceptDefaultResolvers,
  63. dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
  64. endpointManager: service.FromContext[adapter.EndpointManager](ctx),
  65. }, nil
  66. }
  67. func (t *DNSTransport) Start(stage adapter.StartStage) error {
  68. if stage != adapter.StartStateInitialize {
  69. return nil
  70. }
  71. rawOutbound, loaded := t.endpointManager.Get(t.endpointTag)
  72. if !loaded {
  73. return E.New("endpoint not found: ", t.endpointTag)
  74. }
  75. ep, isTailscale := rawOutbound.(*Endpoint)
  76. if !isTailscale {
  77. return E.New("endpoint is not Tailscale: ", t.endpointTag)
  78. }
  79. if ep.onReconfig != nil {
  80. return E.New("only one Tailscale DNS server is allowed for single endpoint")
  81. }
  82. ep.onReconfig = t.onReconfig
  83. t.endpoint = ep
  84. return nil
  85. }
  86. func (t *DNSTransport) Reset() {
  87. }
  88. func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) {
  89. if cfg == nil || dnsCfg == nil {
  90. return
  91. }
  92. if (t.cfg != nil && reflect.DeepEqual(t.cfg, cfg)) && (t.dnsCfg != nil && reflect.DeepEqual(t.dnsCfg, dnsCfg)) {
  93. return
  94. }
  95. t.cfg = cfg
  96. t.dnsCfg = dnsCfg
  97. err := t.updateDNSServers(routerCfg, dnsCfg)
  98. if err != nil {
  99. t.logger.Error(E.Cause(err, "update DNS servers"))
  100. }
  101. }
  102. func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error {
  103. t.routePrefixes = buildRoutePrefixes(routeConfig)
  104. directDialerOnce := sync.OnceValue(func() N.Dialer {
  105. directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{}))
  106. return &DNSDialer{transport: t, fallbackDialer: directDialer}
  107. })
  108. routes := make(map[string][]adapter.DNSTransport)
  109. for domain, resolvers := range dnsConfig.Routes {
  110. var myResolvers []adapter.DNSTransport
  111. for _, resolver := range resolvers {
  112. myResolver, err := t.createResolver(directDialerOnce, resolver)
  113. if err != nil {
  114. return err
  115. }
  116. myResolvers = append(myResolvers, myResolver)
  117. }
  118. routes[domain.WithTrailingDot()] = myResolvers
  119. }
  120. hosts := make(map[string][]netip.Addr)
  121. for domain, addresses := range dnsConfig.Hosts {
  122. hosts[domain.WithTrailingDot()] = addresses
  123. }
  124. var defaultResolvers []adapter.DNSTransport
  125. for _, resolver := range dnsConfig.DefaultResolvers {
  126. myResolver, err := t.createResolver(directDialerOnce, resolver)
  127. if err != nil {
  128. return err
  129. }
  130. defaultResolvers = append(defaultResolvers, myResolver)
  131. }
  132. t.routes = routes
  133. t.hosts = hosts
  134. t.defaultResolvers = defaultResolvers
  135. if len(defaultResolvers) > 0 {
  136. t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ",
  137. strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " "))
  138. } else {
  139. t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts")
  140. }
  141. return nil
  142. }
  143. func (t *DNSTransport) createResolver(directDialer func() N.Dialer, resolver *dnstype.Resolver) (adapter.DNSTransport, error) {
  144. serverURL, parseURLErr := url.Parse(resolver.Addr)
  145. var myDialer N.Dialer
  146. if parseURLErr == nil && serverURL.Scheme == "http" {
  147. myDialer = t.endpoint
  148. } else {
  149. myDialer = directDialer()
  150. }
  151. if len(resolver.BootstrapResolution) > 0 {
  152. bootstrapTransport := transport.NewUDPRaw(t.logger, t.TransportAdapter, myDialer, M.SocksaddrFrom(resolver.BootstrapResolution[0], 53))
  153. myDialer = dialer.NewResolveDialer(t.ctx, myDialer, false, "", adapter.DNSQueryOptions{Transport: bootstrapTransport}, 0)
  154. }
  155. if serverAddr := M.ParseSocksaddr(resolver.Addr); serverAddr.IsValid() {
  156. if serverAddr.Port == 0 {
  157. serverAddr.Port = 53
  158. }
  159. return transport.NewUDPRaw(t.logger, t.TransportAdapter, myDialer, serverAddr), nil
  160. } else if parseURLErr != nil {
  161. return nil, E.Cause(parseURLErr, "parse resolver address")
  162. } else {
  163. switch serverURL.Scheme {
  164. case "https":
  165. serverAddr = M.ParseSocksaddrHostPortStr(serverURL.Hostname(), serverURL.Port())
  166. if serverAddr.Port == 0 {
  167. serverAddr.Port = 443
  168. }
  169. tlsConfig := common.Must1(tls.NewClient(t.ctx, serverAddr.AddrString(), option.OutboundTLSOptions{
  170. ALPN: []string{http2.NextProtoTLS, "http/1.1"},
  171. }))
  172. return transport.NewHTTPSRaw(t.TransportAdapter, t.logger, myDialer, serverURL, http.Header{}, serverAddr, tlsConfig), nil
  173. case "http":
  174. serverAddr = M.ParseSocksaddrHostPortStr(serverURL.Hostname(), serverURL.Port())
  175. if serverAddr.Port == 0 {
  176. serverAddr.Port = 80
  177. }
  178. return transport.NewHTTPSRaw(t.TransportAdapter, t.logger, myDialer, serverURL, http.Header{}, serverAddr, nil), nil
  179. // case "tls":
  180. default:
  181. return nil, E.New("unknown resolver scheme: ", serverURL.Scheme)
  182. }
  183. }
  184. }
  185. func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix {
  186. var builder netipx.IPSetBuilder
  187. for _, localAddr := range routeConfig.LocalAddrs {
  188. builder.AddPrefix(localAddr)
  189. }
  190. for _, route := range routeConfig.Routes {
  191. builder.AddPrefix(route)
  192. }
  193. for _, route := range routeConfig.LocalRoutes {
  194. builder.AddPrefix(route)
  195. }
  196. for _, route := range routeConfig.SubnetRoutes {
  197. builder.AddPrefix(route)
  198. }
  199. ipSet, err := builder.IPSet()
  200. if err != nil {
  201. return nil
  202. }
  203. return ipSet.Prefixes()
  204. }
  205. func (t *DNSTransport) Close() error {
  206. return nil
  207. }
  208. func (t *DNSTransport) Raw() bool {
  209. return true
  210. }
  211. func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
  212. if len(message.Question) != 1 {
  213. return nil, os.ErrInvalid
  214. }
  215. question := message.Question[0]
  216. addresses, hostsLoaded := t.hosts[question.Name]
  217. if hostsLoaded {
  218. switch question.Qtype {
  219. case mDNS.TypeA:
  220. addresses4 := common.Filter(addresses, func(addr netip.Addr) bool {
  221. return addr.Is4()
  222. })
  223. if len(addresses4) > 0 {
  224. return dns.FixedResponse(message.Id, question, addresses4, C.DefaultDNSTTL), nil
  225. }
  226. case mDNS.TypeAAAA:
  227. addresses6 := common.Filter(addresses, func(addr netip.Addr) bool {
  228. return addr.Is6()
  229. })
  230. if len(addresses6) > 0 {
  231. return dns.FixedResponse(message.Id, question, addresses6, C.DefaultDNSTTL), nil
  232. }
  233. }
  234. }
  235. for domainSuffix, transports := range t.routes {
  236. if strings.HasSuffix(question.Name, domainSuffix) {
  237. if len(transports) == 0 {
  238. return &mDNS.Msg{
  239. MsgHdr: mDNS.MsgHdr{
  240. Id: message.Id,
  241. Rcode: mDNS.RcodeNameError,
  242. Response: true,
  243. },
  244. Question: []mDNS.Question{question},
  245. }, nil
  246. }
  247. var lastErr error
  248. for _, dnsTransport := range transports {
  249. response, err := dnsTransport.Exchange(ctx, message)
  250. if err != nil {
  251. lastErr = err
  252. continue
  253. }
  254. return response, nil
  255. }
  256. return nil, lastErr
  257. }
  258. }
  259. if t.acceptDefaultResolvers {
  260. if len(t.defaultResolvers) > 0 {
  261. var lastErr error
  262. for _, resolver := range t.defaultResolvers {
  263. response, err := resolver.Exchange(ctx, message)
  264. if err != nil {
  265. lastErr = err
  266. continue
  267. }
  268. return response, nil
  269. }
  270. return nil, lastErr
  271. } else {
  272. return nil, E.New("missing default resolvers")
  273. }
  274. }
  275. return nil, dns.RcodeNameError
  276. }
  277. type DNSDialer struct {
  278. transport *DNSTransport
  279. fallbackDialer N.Dialer
  280. }
  281. func (d *DNSDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  282. if destination.IsFqdn() {
  283. panic("invalid request here")
  284. }
  285. for _, prefix := range d.transport.routePrefixes {
  286. if prefix.Contains(destination.Addr) {
  287. return d.transport.endpoint.DialContext(ctx, network, destination)
  288. }
  289. }
  290. return d.fallbackDialer.DialContext(ctx, network, destination)
  291. }
  292. func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  293. if destination.IsFqdn() {
  294. panic("invalid request here")
  295. }
  296. for _, prefix := range d.transport.routePrefixes {
  297. if prefix.Contains(destination.Addr) {
  298. return d.transport.endpoint.ListenPacket(ctx, destination)
  299. }
  300. }
  301. return d.fallbackDialer.ListenPacket(ctx, destination)
  302. }