dnscommon.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "strings"
  6. "time"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/errors"
  9. "github.com/xtls/xray-core/common/log"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/session"
  12. "github.com/xtls/xray-core/core"
  13. dns_feature "github.com/xtls/xray-core/features/dns"
  14. "golang.org/x/net/dns/dnsmessage"
  15. )
  // Fqdn returns domain in fully-qualified form, i.e. guaranteed to end with a
  // trailing '.'. A domain that already ends with '.' is returned unchanged;
  // the empty string becomes the root name ".".
  func Fqdn(domain string) string {
  	if !strings.HasSuffix(domain, ".") {
  		domain += "."
  	}
  	return domain
  }
  23. type record struct {
  24. A *IPRecord
  25. AAAA *IPRecord
  26. }
  27. // IPRecord is a cacheable item for a resolved domain
  28. type IPRecord struct {
  29. ReqID uint16
  30. IP []net.Address
  31. Expire time.Time
  32. RCode dnsmessage.RCode
  33. }
  34. func (r *IPRecord) getIPs() ([]net.Address, error) {
  35. if r == nil || r.Expire.Before(time.Now()) {
  36. return nil, errRecordNotFound
  37. }
  38. if r.RCode != dnsmessage.RCodeSuccess {
  39. return nil, dns_feature.RCodeError(r.RCode)
  40. }
  41. return r.IP, nil
  42. }
  43. func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
  44. if newRec == nil {
  45. return false
  46. }
  47. if baseRec == nil {
  48. return true
  49. }
  50. return baseRec.Expire.Before(newRec.Expire)
  51. }
  52. var errRecordNotFound = errors.New("record not found")
  53. type dnsRequest struct {
  54. reqType dnsmessage.Type
  55. domain string
  56. start time.Time
  57. expire time.Time
  58. msg *dnsmessage.Message
  59. }
  60. func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
  61. if len(clientIP) == 0 {
  62. return nil
  63. }
  64. var netmask int
  65. var family uint16
  66. if len(clientIP) == 4 {
  67. family = 1
  68. netmask = 24 // 24 for IPV4, 96 for IPv6
  69. } else {
  70. family = 2
  71. netmask = 96
  72. }
  73. b := make([]byte, 4)
  74. binary.BigEndian.PutUint16(b[0:], family)
  75. b[2] = byte(netmask)
  76. b[3] = 0
  77. switch family {
  78. case 1:
  79. ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
  80. needLength := (netmask + 8 - 1) / 8 // division rounding up
  81. b = append(b, ip[:needLength]...)
  82. case 2:
  83. ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  84. needLength := (netmask + 8 - 1) / 8 // division rounding up
  85. b = append(b, ip[:needLength]...)
  86. }
  87. const EDNS0SUBNET = 0x08
  88. opt := new(dnsmessage.Resource)
  89. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  90. opt.Body = &dnsmessage.OPTResource{
  91. Options: []dnsmessage.Option{
  92. {
  93. Code: EDNS0SUBNET,
  94. Data: b,
  95. },
  96. },
  97. }
  98. return opt
  99. }
  100. func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  101. qA := dnsmessage.Question{
  102. Name: dnsmessage.MustNewName(domain),
  103. Type: dnsmessage.TypeA,
  104. Class: dnsmessage.ClassINET,
  105. }
  106. qAAAA := dnsmessage.Question{
  107. Name: dnsmessage.MustNewName(domain),
  108. Type: dnsmessage.TypeAAAA,
  109. Class: dnsmessage.ClassINET,
  110. }
  111. var reqs []*dnsRequest
  112. now := time.Now()
  113. if option.IPv4Enable {
  114. msg := new(dnsmessage.Message)
  115. msg.Header.ID = reqIDGen()
  116. msg.Header.RecursionDesired = true
  117. msg.Questions = []dnsmessage.Question{qA}
  118. if reqOpts != nil {
  119. msg.Additionals = append(msg.Additionals, *reqOpts)
  120. }
  121. reqs = append(reqs, &dnsRequest{
  122. reqType: dnsmessage.TypeA,
  123. domain: domain,
  124. start: now,
  125. msg: msg,
  126. })
  127. }
  128. if option.IPv6Enable {
  129. msg := new(dnsmessage.Message)
  130. msg.Header.ID = reqIDGen()
  131. msg.Header.RecursionDesired = true
  132. msg.Questions = []dnsmessage.Question{qAAAA}
  133. if reqOpts != nil {
  134. msg.Additionals = append(msg.Additionals, *reqOpts)
  135. }
  136. reqs = append(reqs, &dnsRequest{
  137. reqType: dnsmessage.TypeAAAA,
  138. domain: domain,
  139. start: now,
  140. msg: msg,
  141. })
  142. }
  143. return reqs
  144. }
  145. // parseResponse parses DNS answers from the returned payload
  146. func parseResponse(payload []byte) (*IPRecord, error) {
  147. var parser dnsmessage.Parser
  148. h, err := parser.Start(payload)
  149. if err != nil {
  150. return nil, errors.New("failed to parse DNS response").Base(err).AtWarning()
  151. }
  152. if err := parser.SkipAllQuestions(); err != nil {
  153. return nil, errors.New("failed to skip questions in DNS response").Base(err).AtWarning()
  154. }
  155. now := time.Now()
  156. ipRecord := &IPRecord{
  157. ReqID: h.ID,
  158. RCode: h.RCode,
  159. Expire: now.Add(time.Second * 600),
  160. }
  161. L:
  162. for {
  163. ah, err := parser.AnswerHeader()
  164. if err != nil {
  165. if err != dnsmessage.ErrSectionDone {
  166. errors.LogInfoInner(context.Background(), err, "failed to parse answer section for domain: ", ah.Name.String())
  167. }
  168. break
  169. }
  170. ttl := ah.TTL
  171. if ttl == 0 {
  172. ttl = 600
  173. }
  174. expire := now.Add(time.Duration(ttl) * time.Second)
  175. if ipRecord.Expire.After(expire) {
  176. ipRecord.Expire = expire
  177. }
  178. switch ah.Type {
  179. case dnsmessage.TypeA:
  180. ans, err := parser.AResource()
  181. if err != nil {
  182. errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
  183. break L
  184. }
  185. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
  186. case dnsmessage.TypeAAAA:
  187. ans, err := parser.AAAAResource()
  188. if err != nil {
  189. errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
  190. break L
  191. }
  192. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
  193. default:
  194. if err := parser.SkipAnswer(); err != nil {
  195. errors.LogInfoInner(context.Background(), err, "failed to skip answer")
  196. break L
  197. }
  198. continue
  199. }
  200. }
  201. return ipRecord, nil
  202. }
  203. // toDnsContext create a new background context with parent inbound, session and dns log
  204. func toDnsContext(ctx context.Context, addr string) context.Context {
  205. dnsCtx := core.ToBackgroundDetachedContext(ctx)
  206. if inbound := session.InboundFromContext(ctx); inbound != nil {
  207. dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
  208. }
  209. dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx))
  210. dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{
  211. From: "DNS",
  212. To: addr,
  213. Status: log.AccessAccepted,
  214. Reason: "",
  215. })
  216. return dnsCtx
  217. }