dnscommon.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "math"
  6. "strings"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/errors"
  10. "github.com/xtls/xray-core/common/log"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/core"
  14. dns_feature "github.com/xtls/xray-core/features/dns"
  15. "golang.org/x/net/dns/dnsmessage"
  16. )
  17. // Fqdn normalizes domain make sure it ends with '.'
  18. func Fqdn(domain string) string {
  19. if len(domain) > 0 && strings.HasSuffix(domain, ".") {
  20. return domain
  21. }
  22. return domain + "."
  23. }
  24. type record struct {
  25. A *IPRecord
  26. AAAA *IPRecord
  27. }
  28. // IPRecord is a cacheable item for a resolved domain
  29. type IPRecord struct {
  30. ReqID uint16
  31. IP []net.IP
  32. Expire time.Time
  33. RCode dnsmessage.RCode
  34. RawHeader *dnsmessage.Header
  35. }
  36. func (r *IPRecord) getIPs() ([]net.IP, int32, error) {
  37. if r == nil {
  38. return nil, 0, errRecordNotFound
  39. }
  40. untilExpire := time.Until(r.Expire).Seconds()
  41. ttl := int32(math.Ceil(untilExpire))
  42. if r.RCode != dnsmessage.RCodeSuccess {
  43. return nil, ttl, dns_feature.RCodeError(r.RCode)
  44. }
  45. if len(r.IP) == 0 {
  46. return nil, ttl, dns_feature.ErrEmptyResponse
  47. }
  48. return r.IP, ttl, nil
  49. }
  50. var errRecordNotFound = errors.New("record not found")
  51. type dnsRequest struct {
  52. reqType dnsmessage.Type
  53. domain string
  54. start time.Time
  55. expire time.Time
  56. msg *dnsmessage.Message
  57. }
  58. func genEDNS0Options(clientIP net.IP, padding int) *dnsmessage.Resource {
  59. if len(clientIP) == 0 && padding == 0 {
  60. return nil
  61. }
  62. const EDNS0SUBNET = 0x8
  63. const EDNS0PADDING = 0xc
  64. opt := new(dnsmessage.Resource)
  65. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  66. body := dnsmessage.OPTResource{}
  67. opt.Body = &body
  68. if len(clientIP) != 0 {
  69. var netmask int
  70. var family uint16
  71. if len(clientIP) == 4 {
  72. family = 1
  73. netmask = 24 // 24 for IPV4, 96 for IPv6
  74. } else {
  75. family = 2
  76. netmask = 96
  77. }
  78. b := make([]byte, 4)
  79. binary.BigEndian.PutUint16(b[0:], family)
  80. b[2] = byte(netmask)
  81. b[3] = 0
  82. switch family {
  83. case 1:
  84. ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
  85. needLength := (netmask + 8 - 1) / 8 // division rounding up
  86. b = append(b, ip[:needLength]...)
  87. case 2:
  88. ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  89. needLength := (netmask + 8 - 1) / 8 // division rounding up
  90. b = append(b, ip[:needLength]...)
  91. }
  92. body.Options = append(body.Options,
  93. dnsmessage.Option{
  94. Code: EDNS0SUBNET,
  95. Data: b,
  96. })
  97. }
  98. if padding != 0 {
  99. body.Options = append(body.Options,
  100. dnsmessage.Option{
  101. Code: EDNS0PADDING,
  102. Data: make([]byte, padding),
  103. })
  104. }
  105. return opt
  106. }
  107. func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  108. qA := dnsmessage.Question{
  109. Name: dnsmessage.MustNewName(domain),
  110. Type: dnsmessage.TypeA,
  111. Class: dnsmessage.ClassINET,
  112. }
  113. qAAAA := dnsmessage.Question{
  114. Name: dnsmessage.MustNewName(domain),
  115. Type: dnsmessage.TypeAAAA,
  116. Class: dnsmessage.ClassINET,
  117. }
  118. var reqs []*dnsRequest
  119. now := time.Now()
  120. if option.IPv4Enable {
  121. msg := new(dnsmessage.Message)
  122. msg.Header.ID = reqIDGen()
  123. msg.Header.RecursionDesired = true
  124. msg.Questions = []dnsmessage.Question{qA}
  125. if reqOpts != nil {
  126. msg.Additionals = append(msg.Additionals, *reqOpts)
  127. }
  128. reqs = append(reqs, &dnsRequest{
  129. reqType: dnsmessage.TypeA,
  130. domain: domain,
  131. start: now,
  132. msg: msg,
  133. })
  134. }
  135. if option.IPv6Enable {
  136. msg := new(dnsmessage.Message)
  137. msg.Header.ID = reqIDGen()
  138. msg.Header.RecursionDesired = true
  139. msg.Questions = []dnsmessage.Question{qAAAA}
  140. if reqOpts != nil {
  141. msg.Additionals = append(msg.Additionals, *reqOpts)
  142. }
  143. reqs = append(reqs, &dnsRequest{
  144. reqType: dnsmessage.TypeAAAA,
  145. domain: domain,
  146. start: now,
  147. msg: msg,
  148. })
  149. }
  150. return reqs
  151. }
  152. // parseResponse parses DNS answers from the returned payload
  153. func parseResponse(payload []byte) (*IPRecord, error) {
  154. var parser dnsmessage.Parser
  155. h, err := parser.Start(payload)
  156. if err != nil {
  157. return nil, errors.New("failed to parse DNS response").Base(err).AtWarning()
  158. }
  159. if err := parser.SkipAllQuestions(); err != nil {
  160. return nil, errors.New("failed to skip questions in DNS response").Base(err).AtWarning()
  161. }
  162. now := time.Now()
  163. ipRecord := &IPRecord{
  164. ReqID: h.ID,
  165. RCode: h.RCode,
  166. Expire: now.Add(time.Second * dns_feature.DefaultTTL),
  167. RawHeader: &h,
  168. }
  169. L:
  170. for {
  171. ah, err := parser.AnswerHeader()
  172. if err != nil {
  173. if err != dnsmessage.ErrSectionDone {
  174. errors.LogInfoInner(context.Background(), err, "failed to parse answer section for domain: ", ah.Name.String())
  175. }
  176. break
  177. }
  178. ttl := ah.TTL
  179. if ttl == 0 {
  180. ttl = 1
  181. }
  182. expire := now.Add(time.Duration(ttl) * time.Second)
  183. if ipRecord.Expire.After(expire) {
  184. ipRecord.Expire = expire
  185. }
  186. switch ah.Type {
  187. case dnsmessage.TypeA:
  188. ans, err := parser.AResource()
  189. if err != nil {
  190. errors.LogInfoInner(context.Background(), err, "failed to parse A record for domain: ", ah.Name)
  191. break L
  192. }
  193. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]).IP())
  194. case dnsmessage.TypeAAAA:
  195. ans, err := parser.AAAAResource()
  196. if err != nil {
  197. errors.LogInfoInner(context.Background(), err, "failed to parse AAAA record for domain: ", ah.Name)
  198. break L
  199. }
  200. newIP := net.IPAddress(ans.AAAA[:]).IP()
  201. if len(newIP) == net.IPv6len {
  202. ipRecord.IP = append(ipRecord.IP, newIP)
  203. }
  204. default:
  205. if err := parser.SkipAnswer(); err != nil {
  206. errors.LogInfoInner(context.Background(), err, "failed to skip answer")
  207. break L
  208. }
  209. continue
  210. }
  211. }
  212. return ipRecord, nil
  213. }
  214. // toDnsContext create a new background context with parent inbound, session and dns log
  215. func toDnsContext(ctx context.Context, addr string) context.Context {
  216. dnsCtx := core.ToBackgroundDetachedContext(ctx)
  217. if inbound := session.InboundFromContext(ctx); inbound != nil {
  218. dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
  219. }
  220. dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx))
  221. dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{
  222. From: "DNS",
  223. To: addr,
  224. Status: log.AccessAccepted,
  225. Reason: "",
  226. })
  227. return dnsCtx
  228. }