client.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package dns
  2. import (
  3. "context"
  4. "net"
  5. "net/netip"
  6. "time"
  7. "github.com/sagernet/sing/common"
  8. "github.com/sagernet/sing/common/cache"
  9. E "github.com/sagernet/sing/common/exceptions"
  10. "github.com/sagernet/sing/common/task"
  11. "github.com/sagernet/sing-box/adapter"
  12. C "github.com/sagernet/sing-box/constant"
  13. "golang.org/x/net/dns/dnsmessage"
  14. )
// DefaultTTL is the TTL, in seconds, given to answers synthesized by the
// client. calculateExpire also uses it as the upper bound on how long a
// message stays cached.
const DefaultTTL = 600

var (
	// ErrNoRawSupport is returned by Exchange when the transport cannot carry
	// raw messages and the question is neither A nor AAAA.
	ErrNoRawSupport = E.New("no raw query support by current transport")
	// ErrNotCached is returned by questionCache when the cache holds no entry
	// for the question.
	ErrNotCached = E.New("not cached")
)

// Compile-time check that *Client implements adapter.DNSClient.
var _ adapter.DNSClient = (*Client)(nil)
// Client is a DNS client that keeps responses in an LRU cache keyed by the
// DNS question.
type Client struct {
	// cache maps a question to the message stored for it; entries are written
	// with an explicit expiry time.
	cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message]
}
  24. func NewClient() *Client {
  25. return &Client{
  26. cache: cache.New[dnsmessage.Question, dnsmessage.Message](),
  27. }
  28. }
  29. func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message) (*dnsmessage.Message, error) {
  30. if len(message.Questions) == 0 {
  31. return nil, E.New("empty query")
  32. }
  33. question := message.Questions[0]
  34. cachedAnswer, cached := c.cache.Load(question)
  35. if cached {
  36. cachedAnswer.ID = message.ID
  37. return &cachedAnswer, nil
  38. }
  39. if !transport.Raw() {
  40. if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
  41. return c.exchangeToLookup(ctx, transport, message, question)
  42. }
  43. return nil, ErrNoRawSupport
  44. }
  45. response, err := transport.Exchange(ctx, message)
  46. if err != nil {
  47. return nil, err
  48. }
  49. c.cache.StoreWithExpire(question, *response, calculateExpire(message))
  50. return message, err
  51. }
  52. func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, strategy C.DomainStrategy) ([]netip.Addr, error) {
  53. dnsName, err := dnsmessage.NewName(domain)
  54. if err != nil {
  55. return nil, wrapError(err)
  56. }
  57. if transport.Raw() {
  58. if strategy == C.DomainStrategyUseIPv4 {
  59. return c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeA)
  60. } else if strategy == C.DomainStrategyUseIPv6 {
  61. return c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeAAAA)
  62. }
  63. var response4 []netip.Addr
  64. var response6 []netip.Addr
  65. err = task.Run(ctx, func() error {
  66. response, err := c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeA)
  67. if err != nil {
  68. return err
  69. }
  70. response4 = response
  71. return nil
  72. }, func() error {
  73. response, err := c.lookupToExchange(ctx, transport, dnsName, dnsmessage.TypeAAAA)
  74. if err != nil {
  75. return err
  76. }
  77. response6 = response
  78. return nil
  79. })
  80. if len(response4) == 0 && len(response6) == 0 {
  81. return nil, err
  82. }
  83. return sortAddresses(response4, response6, strategy), nil
  84. }
  85. if strategy == C.DomainStrategyUseIPv4 {
  86. response, err := c.questionCache(dnsmessage.Question{
  87. Name: dnsName,
  88. Type: dnsmessage.TypeA,
  89. Class: dnsmessage.ClassINET,
  90. })
  91. if err != ErrNotCached {
  92. return response, err
  93. }
  94. } else if strategy == C.DomainStrategyUseIPv6 {
  95. response, err := c.questionCache(dnsmessage.Question{
  96. Name: dnsName,
  97. Type: dnsmessage.TypeAAAA,
  98. Class: dnsmessage.ClassINET,
  99. })
  100. if err != ErrNotCached {
  101. return response, err
  102. }
  103. } else {
  104. response4, _ := c.questionCache(dnsmessage.Question{
  105. Name: dnsName,
  106. Type: dnsmessage.TypeA,
  107. Class: dnsmessage.ClassINET,
  108. })
  109. response6, _ := c.questionCache(dnsmessage.Question{
  110. Name: dnsName,
  111. Type: dnsmessage.TypeAAAA,
  112. Class: dnsmessage.ClassINET,
  113. })
  114. if len(response4) > 0 || len(response6) > 0 {
  115. return sortAddresses(response4, response6, strategy), nil
  116. }
  117. }
  118. var rCode dnsmessage.RCode
  119. response, err := transport.Lookup(ctx, domain, strategy)
  120. if err != nil {
  121. err = wrapError(err)
  122. if rCodeError, isRCodeError := err.(RCodeError); !isRCodeError {
  123. return nil, err
  124. } else {
  125. rCode = dnsmessage.RCode(rCodeError)
  126. }
  127. }
  128. header := dnsmessage.Header{
  129. Response: true,
  130. Authoritative: true,
  131. RCode: rCode,
  132. }
  133. expire := time.Now().Add(time.Second * time.Duration(DefaultTTL))
  134. if strategy != C.DomainStrategyUseIPv6 {
  135. question4 := dnsmessage.Question{
  136. Name: dnsName,
  137. Type: dnsmessage.TypeA,
  138. Class: dnsmessage.ClassINET,
  139. }
  140. response4 := common.Filter(response, func(addr netip.Addr) bool {
  141. return addr.Is4() || addr.Is4In6()
  142. })
  143. message4 := dnsmessage.Message{
  144. Header: header,
  145. Questions: []dnsmessage.Question{question4},
  146. }
  147. if len(response4) > 0 {
  148. for _, address := range response4 {
  149. message4.Answers = append(message4.Answers, dnsmessage.Resource{
  150. Header: dnsmessage.ResourceHeader{
  151. Name: question4.Name,
  152. Class: question4.Class,
  153. TTL: DefaultTTL,
  154. },
  155. Body: &dnsmessage.AResource{
  156. A: address.As4(),
  157. },
  158. })
  159. }
  160. }
  161. c.cache.StoreWithExpire(question4, message4, expire)
  162. }
  163. if strategy != C.DomainStrategyUseIPv4 {
  164. question6 := dnsmessage.Question{
  165. Name: dnsName,
  166. Type: dnsmessage.TypeAAAA,
  167. Class: dnsmessage.ClassINET,
  168. }
  169. response6 := common.Filter(response, func(addr netip.Addr) bool {
  170. return addr.Is6() && !addr.Is4In6()
  171. })
  172. message6 := dnsmessage.Message{
  173. Header: header,
  174. Questions: []dnsmessage.Question{question6},
  175. }
  176. if len(response6) > 0 {
  177. for _, address := range response6 {
  178. message6.Answers = append(message6.Answers, dnsmessage.Resource{
  179. Header: dnsmessage.ResourceHeader{
  180. Name: question6.Name,
  181. Class: question6.Class,
  182. TTL: DefaultTTL,
  183. },
  184. Body: &dnsmessage.AAAAResource{
  185. AAAA: address.As16(),
  186. },
  187. })
  188. }
  189. }
  190. c.cache.StoreWithExpire(question6, message6, expire)
  191. }
  192. return response, err
  193. }
  194. func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.DomainStrategy) []netip.Addr {
  195. if strategy == C.DomainStrategyPreferIPv6 {
  196. return append(response6, response4...)
  197. } else {
  198. return append(response4, response6...)
  199. }
  200. }
  201. func calculateExpire(message *dnsmessage.Message) time.Time {
  202. timeToLive := DefaultTTL
  203. for _, answer := range message.Answers {
  204. if int(answer.Header.TTL) < timeToLive {
  205. timeToLive = int(answer.Header.TTL)
  206. }
  207. }
  208. return time.Now().Add(time.Second * time.Duration(timeToLive))
  209. }
  210. func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
  211. domain := question.Name.String()
  212. var strategy C.DomainStrategy
  213. if question.Type == dnsmessage.TypeA {
  214. strategy = C.DomainStrategyUseIPv4
  215. } else {
  216. strategy = C.DomainStrategyUseIPv6
  217. }
  218. var rCode dnsmessage.RCode
  219. result, err := c.Lookup(ctx, transport, domain, strategy)
  220. if err != nil {
  221. err = wrapError(err)
  222. if rCodeError, isRCodeError := err.(RCodeError); !isRCodeError {
  223. return nil, err
  224. } else {
  225. rCode = dnsmessage.RCode(rCodeError)
  226. }
  227. }
  228. response := dnsmessage.Message{
  229. Header: dnsmessage.Header{
  230. ID: message.ID,
  231. RCode: rCode,
  232. RecursionAvailable: true,
  233. RecursionDesired: true,
  234. Response: true,
  235. },
  236. Questions: message.Questions,
  237. }
  238. for _, address := range result {
  239. var resource dnsmessage.Resource
  240. resource.Header = dnsmessage.ResourceHeader{
  241. Name: question.Name,
  242. Class: question.Class,
  243. TTL: DefaultTTL,
  244. }
  245. if address.Is4() || address.Is4In6() {
  246. resource.Body = &dnsmessage.AResource{
  247. A: address.As4(),
  248. }
  249. } else {
  250. resource.Body = &dnsmessage.AAAAResource{
  251. AAAA: address.As16(),
  252. }
  253. }
  254. }
  255. return &response, nil
  256. }
  257. func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name dnsmessage.Name, qType dnsmessage.Type) ([]netip.Addr, error) {
  258. question := dnsmessage.Question{
  259. Name: name,
  260. Type: qType,
  261. Class: dnsmessage.ClassINET,
  262. }
  263. cachedAddresses, err := c.questionCache(question)
  264. if err != ErrNotCached {
  265. return cachedAddresses, err
  266. }
  267. message := dnsmessage.Message{
  268. Header: dnsmessage.Header{
  269. ID: 0,
  270. RecursionDesired: true,
  271. },
  272. Questions: []dnsmessage.Question{question},
  273. }
  274. response, err := c.Exchange(ctx, transport, &message)
  275. if err != nil {
  276. return nil, err
  277. }
  278. return messageToAddresses(response)
  279. }
  280. func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, error) {
  281. response, cached := c.cache.Load(question)
  282. if !cached {
  283. return nil, ErrNotCached
  284. }
  285. return messageToAddresses(&response)
  286. }
  287. func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {
  288. if response.RCode != dnsmessage.RCodeSuccess {
  289. return nil, RCodeError(response.RCode)
  290. }
  291. addresses := make([]netip.Addr, 0, len(response.Answers))
  292. for _, answer := range response.Answers {
  293. switch resource := answer.Body.(type) {
  294. case *dnsmessage.AResource:
  295. addresses = append(addresses, netip.AddrFrom4(resource.A))
  296. case *dnsmessage.AAAAResource:
  297. addresses = append(addresses, netip.AddrFrom16(resource.AAAA))
  298. }
  299. }
  300. return addresses, nil
  301. }
  302. func wrapError(err error) error {
  303. if dnsErr, isDNSError := err.(*net.DNSError); isDNSError {
  304. if dnsErr.IsNotFound {
  305. return RCodeNameError
  306. }
  307. }
  308. return err
  309. }