dnscommon.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "strings"
  6. "time"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/errors"
  9. "github.com/xtls/xray-core/common/log"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/session"
  12. "github.com/xtls/xray-core/core"
  13. dns_feature "github.com/xtls/xray-core/features/dns"
  14. "golang.org/x/net/dns/dnsmessage"
  15. )
  16. // Fqdn normalizes domain make sure it ends with '.'
  17. func Fqdn(domain string) string {
  18. if len(domain) > 0 && strings.HasSuffix(domain, ".") {
  19. return domain
  20. }
  21. return domain + "."
  22. }
  23. type record struct {
  24. A *IPRecord
  25. AAAA *IPRecord
  26. }
  27. // IPRecord is a cacheable item for a resolved domain
  28. type IPRecord struct {
  29. ReqID uint16
  30. IP []net.IP
  31. Expire time.Time
  32. RCode dnsmessage.RCode
  33. RawHeader *dnsmessage.Header
  34. }
  35. func (r *IPRecord) getIPs() ([]net.IP, uint32, error) {
  36. if r == nil {
  37. return nil, 0, errRecordNotFound
  38. }
  39. untilExpire := time.Until(r.Expire).Seconds()
  40. if untilExpire <= 0 {
  41. return nil, 0, errRecordNotFound
  42. }
  43. ttl := uint32(untilExpire) + 1
  44. if ttl == 1 {
  45. r.Expire = time.Now().Add(time.Second) // To ensure that two consecutive requests get the same result
  46. }
  47. if r.RCode != dnsmessage.RCodeSuccess {
  48. return nil, ttl, dns_feature.RCodeError(r.RCode)
  49. }
  50. if len(r.IP) == 0 {
  51. return nil, ttl, dns_feature.ErrEmptyResponse
  52. }
  53. return r.IP, ttl, nil
  54. }
  55. var errRecordNotFound = errors.New("record not found")
  56. type dnsRequest struct {
  57. reqType dnsmessage.Type
  58. domain string
  59. start time.Time
  60. expire time.Time
  61. msg *dnsmessage.Message
  62. }
  63. func genEDNS0Options(clientIP net.IP, padding int) *dnsmessage.Resource {
  64. if len(clientIP) == 0 && padding == 0 {
  65. return nil
  66. }
  67. const EDNS0SUBNET = 0x8
  68. const EDNS0PADDING = 0xc
  69. opt := new(dnsmessage.Resource)
  70. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  71. body := dnsmessage.OPTResource{}
  72. opt.Body = &body
  73. if len(clientIP) != 0 {
  74. var netmask int
  75. var family uint16
  76. if len(clientIP) == 4 {
  77. family = 1
  78. netmask = 24 // 24 for IPV4, 96 for IPv6
  79. } else {
  80. family = 2
  81. netmask = 96
  82. }
  83. b := make([]byte, 4)
  84. binary.BigEndian.PutUint16(b[0:], family)
  85. b[2] = byte(netmask)
  86. b[3] = 0
  87. switch family {
  88. case 1:
  89. ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
  90. needLength := (netmask + 8 - 1) / 8 // division rounding up
  91. b = append(b, ip[:needLength]...)
  92. case 2:
  93. ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  94. needLength := (netmask + 8 - 1) / 8 // division rounding up
  95. b = append(b, ip[:needLength]...)
  96. }
  97. body.Options = append(body.Options,
  98. dnsmessage.Option{
  99. Code: EDNS0SUBNET,
  100. Data: b,
  101. })
  102. }
  103. if padding != 0 {
  104. body.Options = append(body.Options,
  105. dnsmessage.Option{
  106. Code: EDNS0PADDING,
  107. Data: make([]byte, padding),
  108. })
  109. }
  110. return opt
  111. }
  112. func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  113. qA := dnsmessage.Question{
  114. Name: dnsmessage.MustNewName(domain),
  115. Type: dnsmessage.TypeA,
  116. Class: dnsmessage.ClassINET,
  117. }
  118. qAAAA := dnsmessage.Question{
  119. Name: dnsmessage.MustNewName(domain),
  120. Type: dnsmessage.TypeAAAA,
  121. Class: dnsmessage.ClassINET,
  122. }
  123. var reqs []*dnsRequest
  124. now := time.Now()
  125. if option.IPv4Enable {
  126. msg := new(dnsmessage.Message)
  127. msg.Header.ID = reqIDGen()
  128. msg.Header.RecursionDesired = true
  129. msg.Questions = []dnsmessage.Question{qA}
  130. if reqOpts != nil {
  131. msg.Additionals = append(msg.Additionals, *reqOpts)
  132. }
  133. reqs = append(reqs, &dnsRequest{
  134. reqType: dnsmessage.TypeA,
  135. domain: domain,
  136. start: now,
  137. msg: msg,
  138. })
  139. }
  140. if option.IPv6Enable {
  141. msg := new(dnsmessage.Message)
  142. msg.Header.ID = reqIDGen()
  143. msg.Header.RecursionDesired = true
  144. msg.Questions = []dnsmessage.Question{qAAAA}
  145. if reqOpts != nil {
  146. msg.Additionals = append(msg.Additionals, *reqOpts)
  147. }
  148. reqs = append(reqs, &dnsRequest{
  149. reqType: dnsmessage.TypeAAAA,
  150. domain: domain,
  151. start: now,
  152. msg: msg,
  153. })
  154. }
  155. return reqs
  156. }
  157. // parseResponse parses DNS answers from the returned payload
  158. func parseResponse(payload []byte) (*IPRecord, error) {
  159. var parser dnsmessage.Parser
  160. h, err := parser.Start(payload)
  161. if err != nil {
  162. return nil, errors.New("failed to parse DNS response").Base(err).AtWarning()
  163. }
  164. if err := parser.SkipAllQuestions(); err != nil {
  165. return nil, errors.New("failed to skip questions in DNS response").Base(err).AtWarning()
  166. }
  167. now := time.Now()
  168. ipRecord := &IPRecord{
  169. ReqID: h.ID,
  170. RCode: h.RCode,
  171. Expire: now.Add(time.Second * dns_feature.DefaultTTL),
  172. RawHeader: &h,
  173. }
  174. L:
  175. for {
  176. ah, err := parser.AnswerHeader()
  177. if err != nil {
  178. if err != dnsmessage.ErrSectionDone {
  179. errors.LogInfoInner(context.Background(), err, "failed to parse answer section for domain: ", ah.Name.String())
  180. }
  181. break
  182. }
  183. ttl := ah.TTL
  184. if ttl == 0 {
  185. ttl = 1
  186. }
  187. expire := now.Add(time.Duration(ttl) * time.Second)
  188. if ipRecord.Expire.After(expire) {
  189. ipRecord.Expire = expire
  190. }
  191. switch ah.Type {
  192. case dnsmessage.TypeA:
  193. ans, err := parser.AResource()
  194. if err != nil {
  195. errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
  196. break L
  197. }
  198. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]).IP())
  199. case dnsmessage.TypeAAAA:
  200. ans, err := parser.AAAAResource()
  201. if err != nil {
  202. errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
  203. break L
  204. }
  205. newIP := net.IPAddress(ans.AAAA[:]).IP()
  206. if len(newIP) == net.IPv6len {
  207. ipRecord.IP = append(ipRecord.IP, newIP)
  208. }
  209. default:
  210. if err := parser.SkipAnswer(); err != nil {
  211. errors.LogInfoInner(context.Background(), err, "failed to skip answer")
  212. break L
  213. }
  214. continue
  215. }
  216. }
  217. return ipRecord, nil
  218. }
  219. // toDnsContext create a new background context with parent inbound, session and dns log
  220. func toDnsContext(ctx context.Context, addr string) context.Context {
  221. dnsCtx := core.ToBackgroundDetachedContext(ctx)
  222. if inbound := session.InboundFromContext(ctx); inbound != nil {
  223. dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
  224. }
  225. dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx))
  226. dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{
  227. From: "DNS",
  228. To: addr,
  229. Status: log.AccessAccepted,
  230. Reason: "",
  231. })
  232. return dnsCtx
  233. }