Browse Source

Add query_type DNS rule item

世界 2 years ago
parent
commit
687b4509df

+ 4 - 0
adapter/inbound.go

@@ -46,6 +46,10 @@ type InboundContext struct {
 	SourceGeoIPCode      string
 	GeoIPCode            string
 	ProcessInfo          *process.Info
+
+	// dns cache
+
+	QueryType uint16
 }
 
 type inboundContextKey struct{}

+ 9 - 0
docs/configuration/dns/rule.md

@@ -9,6 +9,11 @@
           "mixed-in"
         ],
         "ip_version": 6,
+        "query_type": [
+          "A",
+          "HTTPS",
+          32768
+        ],
         "network": "tcp",
         "auth_user": [
           "usera",
@@ -119,6 +124,10 @@ Tags of [Inbound](/configuration/inbound).
 
 Not limited if empty.
 
+#### query_type
+
+DNS query type. Values can be integers or type name strings.
+
 #### network
 
 `tcp` or `udp`.

+ 9 - 0
docs/configuration/dns/rule.zh.md

@@ -9,6 +9,11 @@
           "mixed-in"
         ],
         "ip_version": 6,
+        "query_type": [
+          "A",
+          "HTTPS",
+          32768
+        ],
         "network": "tcp",
         "auth_user": [
           "usera",
@@ -118,6 +123,10 @@
 
 默认不限制。
 
+#### query_type
+
+DNS 查询类型。值可以为整数或者类型名称字符串。
+
 #### network
 
 `tcp` 或 `udp`。

+ 27 - 26
option/dns.go

@@ -77,32 +77,33 @@ func (r *DNSRule) UnmarshalJSON(bytes []byte) error {
 }
 
 type DefaultDNSRule struct {
-	Inbound         Listable[string] `json:"inbound,omitempty"`
-	IPVersion       int              `json:"ip_version,omitempty"`
-	Network         string           `json:"network,omitempty"`
-	AuthUser        Listable[string] `json:"auth_user,omitempty"`
-	Protocol        Listable[string] `json:"protocol,omitempty"`
-	Domain          Listable[string] `json:"domain,omitempty"`
-	DomainSuffix    Listable[string] `json:"domain_suffix,omitempty"`
-	DomainKeyword   Listable[string] `json:"domain_keyword,omitempty"`
-	DomainRegex     Listable[string] `json:"domain_regex,omitempty"`
-	Geosite         Listable[string] `json:"geosite,omitempty"`
-	SourceGeoIP     Listable[string] `json:"source_geoip,omitempty"`
-	SourceIPCIDR    Listable[string] `json:"source_ip_cidr,omitempty"`
-	SourcePort      Listable[uint16] `json:"source_port,omitempty"`
-	SourcePortRange Listable[string] `json:"source_port_range,omitempty"`
-	Port            Listable[uint16] `json:"port,omitempty"`
-	PortRange       Listable[string] `json:"port_range,omitempty"`
-	ProcessName     Listable[string] `json:"process_name,omitempty"`
-	ProcessPath     Listable[string] `json:"process_path,omitempty"`
-	PackageName     Listable[string] `json:"package_name,omitempty"`
-	User            Listable[string] `json:"user,omitempty"`
-	UserID          Listable[int32]  `json:"user_id,omitempty"`
-	Outbound        Listable[string] `json:"outbound,omitempty"`
-	ClashMode       string           `json:"clash_mode,omitempty"`
-	Invert          bool             `json:"invert,omitempty"`
-	Server          string           `json:"server,omitempty"`
-	DisableCache    bool             `json:"disable_cache,omitempty"`
+	Inbound         Listable[string]       `json:"inbound,omitempty"`
+	IPVersion       int                    `json:"ip_version,omitempty"`
+	QueryType       Listable[DNSQueryType] `json:"query_type,omitempty"`
+	Network         string                 `json:"network,omitempty"`
+	AuthUser        Listable[string]       `json:"auth_user,omitempty"`
+	Protocol        Listable[string]       `json:"protocol,omitempty"`
+	Domain          Listable[string]       `json:"domain,omitempty"`
+	DomainSuffix    Listable[string]       `json:"domain_suffix,omitempty"`
+	DomainKeyword   Listable[string]       `json:"domain_keyword,omitempty"`
+	DomainRegex     Listable[string]       `json:"domain_regex,omitempty"`
+	Geosite         Listable[string]       `json:"geosite,omitempty"`
+	SourceGeoIP     Listable[string]       `json:"source_geoip,omitempty"`
+	SourceIPCIDR    Listable[string]       `json:"source_ip_cidr,omitempty"`
+	SourcePort      Listable[uint16]       `json:"source_port,omitempty"`
+	SourcePortRange Listable[string]       `json:"source_port_range,omitempty"`
+	Port            Listable[uint16]       `json:"port,omitempty"`
+	PortRange       Listable[string]       `json:"port_range,omitempty"`
+	ProcessName     Listable[string]       `json:"process_name,omitempty"`
+	ProcessPath     Listable[string]       `json:"process_path,omitempty"`
+	PackageName     Listable[string]       `json:"package_name,omitempty"`
+	User            Listable[string]       `json:"user,omitempty"`
+	UserID          Listable[int32]        `json:"user_id,omitempty"`
+	Outbound        Listable[string]       `json:"outbound,omitempty"`
+	ClashMode       string                 `json:"clash_mode,omitempty"`
+	Invert          bool                   `json:"invert,omitempty"`
+	Server          string                 `json:"server,omitempty"`
+	DisableCache    bool                   `json:"disable_cache,omitempty"`
 }
 
 func (r DefaultDNSRule) IsValid() bool {

+ 40 - 0
option/types.go

@@ -8,7 +8,10 @@ import (
 	"github.com/sagernet/sing-box/common/json"
 	"github.com/sagernet/sing-dns"
 	E "github.com/sagernet/sing/common/exceptions"
+	F "github.com/sagernet/sing/common/format"
 	N "github.com/sagernet/sing/common/network"
+
+	mDNS "github.com/miekg/dns"
 )
 
 type ListenAddress netip.Addr
@@ -187,3 +190,40 @@ func (p *ListenPrefix) UnmarshalJSON(bytes []byte) error {
 func (p ListenPrefix) Build() netip.Prefix {
 	return netip.Prefix(p)
 }
+
+type DNSQueryType uint16
+
+func (t DNSQueryType) MarshalJSON() ([]byte, error) {
+	typeName, loaded := mDNS.TypeToString[uint16(t)]
+	if loaded {
+		return json.Marshal(typeName)
+	}
+	return json.Marshal(uint16(t))
+}
+
+func (t *DNSQueryType) UnmarshalJSON(bytes []byte) error {
+	var valueNumber uint16
+	err := json.Unmarshal(bytes, &valueNumber)
+	if err == nil {
+		*t = DNSQueryType(valueNumber)
+		return nil
+	}
+	var valueString string
+	err = json.Unmarshal(bytes, &valueString)
+	if err == nil {
+		queryType, loaded := mDNS.StringToType[valueString]
+		if loaded {
+			*t = DNSQueryType(queryType)
+			return nil
+		}
+	}
+	return E.New("unknown DNS query type: ", string(bytes))
+}
+
+func DNSQueryTypeToString(queryType uint16) string {
+	typeName, loaded := mDNS.TypeToString[queryType]
+	if loaded {
+		return typeName
+	}
+	return F.ToString(queryType)
+}

+ 2 - 1
route/router_dns.go

@@ -50,7 +50,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
 	}
 	ctx, metadata := adapter.AppendContext(ctx)
 	if len(message.Question) > 0 {
-		switch message.Question[0].Qtype {
+		metadata.QueryType = message.Question[0].Qtype
+		switch metadata.QueryType {
 		case mDNS.TypeA:
 			metadata.IPVersion = 4
 		case mDNS.TypeAAAA:

+ 5 - 0
route/rule_dns.go

@@ -71,6 +71,11 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
 			return nil, E.New("invalid ip version: ", options.IPVersion)
 		}
 	}
+	if len(options.QueryType) > 0 {
+		item := NewQueryTypeItem(options.QueryType)
+		rule.items = append(rule.items, item)
+		rule.allItems = append(rule.allItems, item)
+	}
 	if options.Network != "" {
 		switch options.Network {
 		case N.NetworkTCP, N.NetworkUDP:

+ 47 - 0
route/rule_query_type.go

@@ -0,0 +1,47 @@
+package route
+
+import (
+	"strings"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing/common"
+)
+
+var _ RuleItem = (*QueryTypeItem)(nil)
+
+type QueryTypeItem struct {
+	typeList []uint16
+	typeMap  map[uint16]bool
+}
+
+func NewQueryTypeItem(typeList []option.DNSQueryType) *QueryTypeItem {
+	rule := &QueryTypeItem{
+		typeList: common.Map(typeList, func(it option.DNSQueryType) uint16 {
+			return uint16(it)
+		}),
+		typeMap: make(map[uint16]bool),
+	}
+	for _, userId := range rule.typeList {
+		rule.typeMap[userId] = true
+	}
+	return rule
+}
+
+func (r *QueryTypeItem) Match(metadata *adapter.InboundContext) bool {
+	if metadata.QueryType == 0 {
+		return false
+	}
+	return r.typeMap[metadata.QueryType]
+}
+
+func (r *QueryTypeItem) String() string {
+	var description string
+	pLen := len(r.typeList)
+	if pLen == 1 {
+		description = "query_type=" + option.DNSQueryTypeToString(r.typeList[0])
+	} else {
+		description = "query_type=[" + strings.Join(common.Map(r.typeList, option.DNSQueryTypeToString), " ") + "]"
+	}
+	return description
+}