1
0

dnscommon.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "math"
  6. "strings"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/errors"
  10. "github.com/xtls/xray-core/common/log"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/core"
  14. dns_feature "github.com/xtls/xray-core/features/dns"
  15. "golang.org/x/net/dns/dnsmessage"
  16. )
  17. // Fqdn normalizes domain make sure it ends with '.'
  18. // case-sensitive
  19. func Fqdn(domain string) string {
  20. if len(domain) > 0 && strings.HasSuffix(domain, ".") {
  21. return domain
  22. }
  23. return domain + "."
  24. }
// record pairs two IPRecord values, one per address family.
// NOTE(review): presumably this is the per-domain cache entry; confirm
// against the code that stores it.
type record struct {
	// A is the IPRecord for the A (IPv4) lookup; may be nil.
	A *IPRecord
	// AAAA is the IPRecord for the AAAA (IPv6) lookup; may be nil.
	AAAA *IPRecord
}
// IPRecord is a cacheable item for a resolved domain.
// The field descriptions below reflect how parseResponse populates it.
type IPRecord struct {
	// ReqID is the message ID copied from the response header.
	ReqID uint16
	// IP holds the addresses collected from the A and AAAA answers.
	IP []net.IP
	// Expire is the instant getIPs measures the remaining TTL against.
	// parseResponse sets it to the parse time plus the smallest answer TTL,
	// capped at dns_feature.DefaultTTL seconds.
	Expire time.Time
	// RCode is the response code copied from the response header.
	RCode dnsmessage.RCode
	// RawHeader points to the full parsed response header.
	RawHeader *dnsmessage.Header
}
  37. func (r *IPRecord) getIPs() ([]net.IP, int32, error) {
  38. if r == nil {
  39. return nil, 0, errRecordNotFound
  40. }
  41. untilExpire := time.Until(r.Expire).Seconds()
  42. ttl := int32(math.Ceil(untilExpire))
  43. if r.RCode != dnsmessage.RCodeSuccess {
  44. return nil, ttl, dns_feature.RCodeError(r.RCode)
  45. }
  46. if len(r.IP) == 0 {
  47. return nil, ttl, dns_feature.ErrEmptyResponse
  48. }
  49. return r.IP, ttl, nil
  50. }
// errRecordNotFound is returned by (*IPRecord).getIPs when the receiver is nil.
var errRecordNotFound = errors.New("record not found")
// dnsRequest bundles one outgoing DNS query message with its bookkeeping
// data. buildReqMsgs creates these values.
type dnsRequest struct {
	// reqType is the record type being asked for (buildReqMsgs uses TypeA
	// or TypeAAAA).
	reqType dnsmessage.Type
	// domain is the queried name exactly as passed to buildReqMsgs.
	domain string
	// start is the time at which the request was built.
	start time.Time
	// expire is not set by buildReqMsgs.
	// NOTE(review): presumably a deadline filled in by the sender; confirm
	// against callers.
	expire time.Time
	// msg is the DNS message to send.
	msg *dnsmessage.Message
}
  59. func genEDNS0Options(clientIP net.IP, padding int) *dnsmessage.Resource {
  60. if len(clientIP) == 0 && padding == 0 {
  61. return nil
  62. }
  63. const EDNS0SUBNET = 0x8
  64. const EDNS0PADDING = 0xc
  65. opt := new(dnsmessage.Resource)
  66. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  67. body := dnsmessage.OPTResource{}
  68. opt.Body = &body
  69. if len(clientIP) != 0 {
  70. var netmask int
  71. var family uint16
  72. if len(clientIP) == 4 {
  73. family = 1
  74. netmask = 24 // 24 for IPV4, 96 for IPv6
  75. } else {
  76. family = 2
  77. netmask = 96
  78. }
  79. b := make([]byte, 4)
  80. binary.BigEndian.PutUint16(b[0:], family)
  81. b[2] = byte(netmask)
  82. b[3] = 0
  83. switch family {
  84. case 1:
  85. ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
  86. needLength := (netmask + 8 - 1) / 8 // division rounding up
  87. b = append(b, ip[:needLength]...)
  88. case 2:
  89. ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  90. needLength := (netmask + 8 - 1) / 8 // division rounding up
  91. b = append(b, ip[:needLength]...)
  92. }
  93. body.Options = append(body.Options,
  94. dnsmessage.Option{
  95. Code: EDNS0SUBNET,
  96. Data: b,
  97. })
  98. }
  99. if padding != 0 {
  100. body.Options = append(body.Options,
  101. dnsmessage.Option{
  102. Code: EDNS0PADDING,
  103. Data: make([]byte, padding),
  104. })
  105. }
  106. return opt
  107. }
  108. func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  109. qA := dnsmessage.Question{
  110. Name: dnsmessage.MustNewName(domain),
  111. Type: dnsmessage.TypeA,
  112. Class: dnsmessage.ClassINET,
  113. }
  114. qAAAA := dnsmessage.Question{
  115. Name: dnsmessage.MustNewName(domain),
  116. Type: dnsmessage.TypeAAAA,
  117. Class: dnsmessage.ClassINET,
  118. }
  119. var reqs []*dnsRequest
  120. now := time.Now()
  121. if option.IPv4Enable {
  122. msg := new(dnsmessage.Message)
  123. msg.Header.ID = reqIDGen()
  124. msg.Header.RecursionDesired = true
  125. msg.Questions = []dnsmessage.Question{qA}
  126. if reqOpts != nil {
  127. msg.Additionals = append(msg.Additionals, *reqOpts)
  128. }
  129. reqs = append(reqs, &dnsRequest{
  130. reqType: dnsmessage.TypeA,
  131. domain: domain,
  132. start: now,
  133. msg: msg,
  134. })
  135. }
  136. if option.IPv6Enable {
  137. msg := new(dnsmessage.Message)
  138. msg.Header.ID = reqIDGen()
  139. msg.Header.RecursionDesired = true
  140. msg.Questions = []dnsmessage.Question{qAAAA}
  141. if reqOpts != nil {
  142. msg.Additionals = append(msg.Additionals, *reqOpts)
  143. }
  144. reqs = append(reqs, &dnsRequest{
  145. reqType: dnsmessage.TypeAAAA,
  146. domain: domain,
  147. start: now,
  148. msg: msg,
  149. })
  150. }
  151. return reqs
  152. }
  153. // parseResponse parses DNS answers from the returned payload
  154. func parseResponse(payload []byte) (*IPRecord, error) {
  155. var parser dnsmessage.Parser
  156. h, err := parser.Start(payload)
  157. if err != nil {
  158. return nil, errors.New("failed to parse DNS response").Base(err).AtWarning()
  159. }
  160. if err := parser.SkipAllQuestions(); err != nil {
  161. return nil, errors.New("failed to skip questions in DNS response").Base(err).AtWarning()
  162. }
  163. now := time.Now()
  164. ipRecord := &IPRecord{
  165. ReqID: h.ID,
  166. RCode: h.RCode,
  167. Expire: now.Add(time.Second * dns_feature.DefaultTTL),
  168. RawHeader: &h,
  169. }
  170. L:
  171. for {
  172. ah, err := parser.AnswerHeader()
  173. if err != nil {
  174. if err != dnsmessage.ErrSectionDone {
  175. errors.LogInfoInner(context.Background(), err, "failed to parse answer section for domain: ", ah.Name.String())
  176. }
  177. break
  178. }
  179. ttl := ah.TTL
  180. if ttl == 0 {
  181. ttl = 1
  182. }
  183. expire := now.Add(time.Duration(ttl) * time.Second)
  184. if ipRecord.Expire.After(expire) {
  185. ipRecord.Expire = expire
  186. }
  187. switch ah.Type {
  188. case dnsmessage.TypeA:
  189. ans, err := parser.AResource()
  190. if err != nil {
  191. errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
  192. break L
  193. }
  194. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]).IP())
  195. case dnsmessage.TypeAAAA:
  196. ans, err := parser.AAAAResource()
  197. if err != nil {
  198. errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
  199. break L
  200. }
  201. newIP := net.IPAddress(ans.AAAA[:]).IP()
  202. if len(newIP) == net.IPv6len {
  203. ipRecord.IP = append(ipRecord.IP, newIP)
  204. }
  205. default:
  206. if err := parser.SkipAnswer(); err != nil {
  207. errors.LogInfoInner(context.Background(), err, "failed to skip answer")
  208. break L
  209. }
  210. continue
  211. }
  212. }
  213. return ipRecord, nil
  214. }
  215. // toDnsContext create a new background context with parent inbound, session and dns log
  216. func toDnsContext(ctx context.Context, addr string) context.Context {
  217. dnsCtx := core.ToBackgroundDetachedContext(ctx)
  218. if inbound := session.InboundFromContext(ctx); inbound != nil {
  219. dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
  220. }
  221. dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx))
  222. dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{
  223. From: "DNS",
  224. To: addr,
  225. Status: log.AccessAccepted,
  226. Reason: "",
  227. })
  228. return dnsCtx
  229. }