dns_record.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package option
  2. import (
  3. "encoding/base64"
  4. "strings"
  5. "github.com/sagernet/sing/common/buf"
  6. E "github.com/sagernet/sing/common/exceptions"
  7. "github.com/sagernet/sing/common/json"
  8. M "github.com/sagernet/sing/common/metadata"
  9. "github.com/miekg/dns"
  10. )
  // defaultDNSRecordTTL is the TTL that parseDNSRecord hands to the zone
  // parser for records written without an explicit TTL. The value is one
  // hour (DNS TTLs are counted in seconds).
  const defaultDNSRecordTTL = uint32(60 * 60)

  // DNSRCode is a DNS response code as it appears in JSON options: either the
  // code's name from dns.RcodeToString or, for codes without a name, a number.
  type DNSRCode int
  13. func (r DNSRCode) MarshalJSON() ([]byte, error) {
  14. rCodeValue, loaded := dns.RcodeToString[int(r)]
  15. if loaded {
  16. return json.Marshal(rCodeValue)
  17. }
  18. return json.Marshal(int(r))
  19. }
  20. func (r *DNSRCode) UnmarshalJSON(bytes []byte) error {
  21. var intValue int
  22. err := json.Unmarshal(bytes, &intValue)
  23. if err == nil {
  24. *r = DNSRCode(intValue)
  25. return nil
  26. }
  27. var stringValue string
  28. err = json.Unmarshal(bytes, &stringValue)
  29. if err != nil {
  30. return err
  31. }
  32. rCodeValue, loaded := dns.StringToRcode[stringValue]
  33. if !loaded {
  34. return E.New("unknown rcode: " + stringValue)
  35. }
  36. *r = DNSRCode(rCodeValue)
  37. return nil
  38. }
  39. func (r *DNSRCode) Build() int {
  40. if r == nil {
  41. return dns.RcodeSuccess
  42. }
  43. return int(*r)
  44. }
  // DNSRecordOptions holds one DNS resource record taken from JSON options,
  // given either as zone-file text or as base64-encoded wire format.
  type DNSRecordOptions struct {
  	// The parsed record. Embedding promotes its methods (e.g. String) onto
  	// DNSRecordOptions.
  	dns.RR
  	// fromBase64 is set when the record was decoded from base64 wire format,
  	// so that MarshalJSON writes it back in the same form.
  	fromBase64 bool
  }
  49. func (o DNSRecordOptions) MarshalJSON() ([]byte, error) {
  50. if o.fromBase64 {
  51. buffer := buf.Get(dns.Len(o.RR))
  52. defer buf.Put(buffer)
  53. offset, err := dns.PackRR(o.RR, buffer, 0, nil, false)
  54. if err != nil {
  55. return nil, err
  56. }
  57. return json.Marshal(base64.StdEncoding.EncodeToString(buffer[:offset]))
  58. }
  59. return json.Marshal(o.RR.String())
  60. }
  61. func (o *DNSRecordOptions) UnmarshalJSON(data []byte) error {
  62. var stringValue string
  63. err := json.Unmarshal(data, &stringValue)
  64. if err != nil {
  65. return err
  66. }
  67. binary, err := base64.StdEncoding.DecodeString(stringValue)
  68. if err == nil {
  69. return o.unmarshalBase64(binary)
  70. }
  71. record, err := parseDNSRecord(stringValue)
  72. if err != nil {
  73. return err
  74. }
  75. if record == nil {
  76. return E.New("empty DNS record")
  77. }
  78. if a, isA := record.(*dns.A); isA {
  79. a.A = M.AddrFromIP(a.A).Unmap().AsSlice()
  80. }
  81. o.RR = record
  82. return nil
  83. }
  84. func parseDNSRecord(stringValue string) (dns.RR, error) {
  85. if len(stringValue) > 0 && stringValue[len(stringValue)-1] != '\n' {
  86. stringValue += "\n"
  87. }
  88. parser := dns.NewZoneParser(strings.NewReader(stringValue), "", "")
  89. parser.SetDefaultTTL(defaultDNSRecordTTL)
  90. record, _ := parser.Next()
  91. return record, parser.Err()
  92. }
  93. func (o *DNSRecordOptions) unmarshalBase64(binary []byte) error {
  94. record, _, err := dns.UnpackRR(binary, 0)
  95. if err != nil {
  96. return E.New("parse binary DNS record")
  97. }
  98. o.RR = record
  99. o.fromBase64 = true
  100. return nil
  101. }
  102. func (o DNSRecordOptions) Build() dns.RR {
  103. return o.RR
  104. }
  105. func (o DNSRecordOptions) Match(record dns.RR) bool {
  106. if o.RR == nil || record == nil {
  107. return false
  108. }
  109. return dns.IsDuplicate(o.RR, record)
  110. }