dns.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. package dns
  2. import (
  3. "context"
  4. go_errors "errors"
  5. "io"
  6. "sync"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/net"
  12. dns_proto "github.com/xtls/xray-core/common/protocol/dns"
  13. "github.com/xtls/xray-core/common/session"
  14. "github.com/xtls/xray-core/common/signal"
  15. "github.com/xtls/xray-core/common/task"
  16. "github.com/xtls/xray-core/core"
  17. "github.com/xtls/xray-core/features/dns"
  18. "github.com/xtls/xray-core/features/policy"
  19. "github.com/xtls/xray-core/transport"
  20. "github.com/xtls/xray-core/transport/internet"
  21. "github.com/xtls/xray-core/transport/internet/stat"
  22. "golang.org/x/net/dns/dnsmessage"
  23. )
  24. func init() {
  25. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  26. h := new(Handler)
  27. if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {
  28. core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) {
  29. h.fdns = fdns
  30. })
  31. return h.Init(config.(*Config), dnsClient, policyManager)
  32. }); err != nil {
  33. return nil, err
  34. }
  35. return h, nil
  36. }))
  37. }
  38. type ownLinkVerifier interface {
  39. IsOwnLink(ctx context.Context) bool
  40. }
  41. type Handler struct {
  42. client dns.Client
  43. fdns dns.FakeDNSEngine
  44. ownLinkVerifier ownLinkVerifier
  45. server net.Destination
  46. timeout time.Duration
  47. nonIPQuery string
  48. blockTypes []int32
  49. }
  50. func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {
  51. h.client = dnsClient
  52. h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle
  53. if v, ok := dnsClient.(ownLinkVerifier); ok {
  54. h.ownLinkVerifier = v
  55. }
  56. if config.Server != nil {
  57. h.server = config.Server.AsDestination()
  58. }
  59. h.nonIPQuery = config.Non_IPQuery
  60. h.blockTypes = config.BlockTypes
  61. return nil
  62. }
  63. func (h *Handler) isOwnLink(ctx context.Context) bool {
  64. return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)
  65. }
  66. func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {
  67. var parser dnsmessage.Parser
  68. header, err := parser.Start(b)
  69. if err != nil {
  70. errors.LogInfoInner(context.Background(), err, "parser start")
  71. return
  72. }
  73. id = header.ID
  74. q, err := parser.Question()
  75. if err != nil {
  76. errors.LogInfoInner(context.Background(), err, "question")
  77. return
  78. }
  79. domain = q.Name.String()
  80. qType = q.Type
  81. if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {
  82. return
  83. }
  84. r = true
  85. return
  86. }
  87. // Process implements proxy.Outbound.
  88. func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {
  89. outbounds := session.OutboundsFromContext(ctx)
  90. ob := outbounds[len(outbounds)-1]
  91. if !ob.Target.IsValid() {
  92. return errors.New("invalid outbound")
  93. }
  94. ob.Name = "dns"
  95. srcNetwork := ob.Target.Network
  96. dest := ob.Target
  97. if h.server.Network != net.Network_Unknown {
  98. dest.Network = h.server.Network
  99. }
  100. if h.server.Address != nil {
  101. dest.Address = h.server.Address
  102. }
  103. if h.server.Port != 0 {
  104. dest.Port = h.server.Port
  105. }
  106. errors.LogInfo(ctx, "handling DNS traffic to ", dest)
  107. conn := &outboundConn{
  108. dialer: func() (stat.Connection, error) {
  109. return d.Dial(ctx, dest)
  110. },
  111. connReady: make(chan struct{}, 1),
  112. }
  113. var reader dns_proto.MessageReader
  114. var writer dns_proto.MessageWriter
  115. if srcNetwork == net.Network_TCP {
  116. reader = dns_proto.NewTCPReader(link.Reader)
  117. writer = &dns_proto.TCPWriter{
  118. Writer: link.Writer,
  119. }
  120. } else {
  121. reader = &dns_proto.UDPReader{
  122. Reader: link.Reader,
  123. }
  124. writer = &dns_proto.UDPWriter{
  125. Writer: link.Writer,
  126. }
  127. }
  128. var connReader dns_proto.MessageReader
  129. var connWriter dns_proto.MessageWriter
  130. if dest.Network == net.Network_TCP {
  131. connReader = dns_proto.NewTCPReader(buf.NewReader(conn))
  132. connWriter = &dns_proto.TCPWriter{
  133. Writer: buf.NewWriter(conn),
  134. }
  135. } else {
  136. connReader = &dns_proto.UDPReader{
  137. Reader: buf.NewPacketReader(conn),
  138. }
  139. connWriter = &dns_proto.UDPWriter{
  140. Writer: buf.NewWriter(conn),
  141. }
  142. }
  143. if session.TimeoutOnlyFromContext(ctx) {
  144. ctx, _ = context.WithCancel(context.Background())
  145. }
  146. ctx, cancel := context.WithCancel(ctx)
  147. timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
  148. request := func() error {
  149. defer conn.Close()
  150. for {
  151. b, err := reader.ReadMessage()
  152. if err == io.EOF {
  153. return nil
  154. }
  155. if err != nil {
  156. return err
  157. }
  158. timer.Update()
  159. if !h.isOwnLink(ctx) {
  160. isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
  161. if len(h.blockTypes) > 0 {
  162. for _, blocktype := range h.blockTypes {
  163. if blocktype == int32(qType) {
  164. if h.nonIPQuery == "reject" {
  165. go h.rejectNonIPQuery(id, qType, domain, writer)
  166. }
  167. errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
  168. return nil
  169. }
  170. }
  171. }
  172. if isIPQuery {
  173. go h.handleIPQuery(id, qType, domain, writer)
  174. }
  175. if isIPQuery || h.nonIPQuery == "drop" {
  176. b.Release()
  177. continue
  178. }
  179. if h.nonIPQuery == "reject" {
  180. go h.rejectNonIPQuery(id, qType, domain, writer)
  181. b.Release()
  182. continue
  183. }
  184. }
  185. if err := connWriter.WriteMessage(b); err != nil {
  186. return err
  187. }
  188. }
  189. }
  190. response := func() error {
  191. for {
  192. b, err := connReader.ReadMessage()
  193. if err == io.EOF {
  194. return nil
  195. }
  196. if err != nil {
  197. return err
  198. }
  199. timer.Update()
  200. if err := writer.WriteMessage(b); err != nil {
  201. return err
  202. }
  203. }
  204. }
  205. if err := task.Run(ctx, request, response); err != nil {
  206. return errors.New("connection ends").Base(err)
  207. }
  208. return nil
  209. }
  210. func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
  211. var ips []net.IP
  212. var err error
  213. var ttl4 uint32
  214. var ttl6 uint32
  215. switch qType {
  216. case dnsmessage.TypeA:
  217. ips, ttl4, err = h.client.LookupIP(domain, dns.IPOption{
  218. IPv4Enable: true,
  219. IPv6Enable: false,
  220. FakeEnable: true,
  221. })
  222. case dnsmessage.TypeAAAA:
  223. ips, ttl6, err = h.client.LookupIP(domain, dns.IPOption{
  224. IPv4Enable: false,
  225. IPv6Enable: true,
  226. FakeEnable: true,
  227. })
  228. }
  229. rcode := dns.RCodeFromError(err)
  230. if rcode == 0 && len(ips) == 0 && !go_errors.Is(err, dns.ErrEmptyResponse) {
  231. errors.LogInfoInner(context.Background(), err, "ip query")
  232. return
  233. }
  234. switch qType {
  235. case dnsmessage.TypeA:
  236. for i, ip := range ips {
  237. ips[i] = ip.To4()
  238. }
  239. case dnsmessage.TypeAAAA:
  240. for i, ip := range ips {
  241. ips[i] = ip.To16()
  242. }
  243. }
  244. b := buf.New()
  245. rawBytes := b.Extend(buf.Size)
  246. builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
  247. ID: id,
  248. RCode: dnsmessage.RCode(rcode),
  249. RecursionAvailable: true,
  250. RecursionDesired: true,
  251. Response: true,
  252. Authoritative: true,
  253. })
  254. builder.EnableCompression()
  255. common.Must(builder.StartQuestions())
  256. common.Must(builder.Question(dnsmessage.Question{
  257. Name: dnsmessage.MustNewName(domain),
  258. Class: dnsmessage.ClassINET,
  259. Type: qType,
  260. }))
  261. common.Must(builder.StartAnswers())
  262. rHeader4 := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl4}
  263. rHeader6 := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl6}
  264. for _, ip := range ips {
  265. if len(ip) == net.IPv4len {
  266. var r dnsmessage.AResource
  267. copy(r.A[:], ip)
  268. common.Must(builder.AResource(rHeader4, r))
  269. } else {
  270. var r dnsmessage.AAAAResource
  271. copy(r.AAAA[:], ip)
  272. common.Must(builder.AAAAResource(rHeader6, r))
  273. }
  274. }
  275. msgBytes, err := builder.Finish()
  276. if err != nil {
  277. errors.LogInfoInner(context.Background(), err, "pack message")
  278. b.Release()
  279. return
  280. }
  281. b.Resize(0, int32(len(msgBytes)))
  282. if err := writer.WriteMessage(b); err != nil {
  283. errors.LogInfoInner(context.Background(), err, "write IP answer")
  284. }
  285. }
  286. func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
  287. b := buf.New()
  288. rawBytes := b.Extend(buf.Size)
  289. builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
  290. ID: id,
  291. RCode: dnsmessage.RCodeRefused,
  292. RecursionAvailable: true,
  293. RecursionDesired: true,
  294. Response: true,
  295. Authoritative: true,
  296. })
  297. builder.EnableCompression()
  298. common.Must(builder.StartQuestions())
  299. err := builder.Question(dnsmessage.Question{
  300. Name: dnsmessage.MustNewName(domain),
  301. Class: dnsmessage.ClassINET,
  302. Type: qType,
  303. })
  304. if err != nil {
  305. errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err)
  306. b.Release()
  307. return
  308. }
  309. msgBytes, err := builder.Finish()
  310. if err != nil {
  311. errors.LogInfoInner(context.Background(), err, "pack reject message")
  312. b.Release()
  313. return
  314. }
  315. b.Resize(0, int32(len(msgBytes)))
  316. if err := writer.WriteMessage(b); err != nil {
  317. errors.LogInfoInner(context.Background(), err, "write reject answer")
  318. }
  319. }
  320. type outboundConn struct {
  321. access sync.Mutex
  322. dialer func() (stat.Connection, error)
  323. conn net.Conn
  324. connReady chan struct{}
  325. }
  326. func (c *outboundConn) dial() error {
  327. conn, err := c.dialer()
  328. if err != nil {
  329. return err
  330. }
  331. c.conn = conn
  332. c.connReady <- struct{}{}
  333. return nil
  334. }
  335. func (c *outboundConn) Write(b []byte) (int, error) {
  336. c.access.Lock()
  337. if c.conn == nil {
  338. if err := c.dial(); err != nil {
  339. c.access.Unlock()
  340. errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")
  341. return len(b), nil
  342. }
  343. }
  344. c.access.Unlock()
  345. return c.conn.Write(b)
  346. }
  347. func (c *outboundConn) Read(b []byte) (int, error) {
  348. var conn net.Conn
  349. c.access.Lock()
  350. conn = c.conn
  351. c.access.Unlock()
  352. if conn == nil {
  353. _, open := <-c.connReady
  354. if !open {
  355. return 0, io.EOF
  356. }
  357. conn = c.conn
  358. }
  359. return conn.Read(b)
  360. }
  361. func (c *outboundConn) Close() error {
  362. c.access.Lock()
  363. close(c.connReady)
  364. if c.conn != nil {
  365. c.conn.Close()
  366. }
  367. c.access.Unlock()
  368. return nil
  369. }