فهرست منبع

Add API to dump AdGuard rules

世界 4 ماه پیش
والد
کامیت
bea9048cfe

+ 1 - 1
cmd/sing-box/cmd_rule_set_convert.go

@@ -54,7 +54,7 @@ func convertRuleSet(sourcePath string) error {
 	var rules []option.HeadlessRule
 	switch flagRuleSetConvertType {
 	case "adguard":
-		rules, err = adguard.Convert(reader, log.StdLogger())
+		rules, err = adguard.ToOptions(reader, log.StdLogger())
 	case "":
 		return E.New("source type is required")
 	default:

+ 24 - 0
cmd/sing-box/cmd_rule_set_decompile.go

@@ -6,7 +6,10 @@ import (
 	"strings"
 
 	"github.com/sagernet/sing-box/common/srs"
+	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/json"
 
 	"github.com/spf13/cobra"
@@ -50,6 +53,11 @@ func decompileRuleSet(sourcePath string) error {
 	if err != nil {
 		return err
 	}
+	if hasRule(ruleSet.Options.Rules, func(rule option.DefaultHeadlessRule) bool {
+		return len(rule.AdGuardDomain) > 0
+	}) {
+		return E.New("unable to decompile binary AdGuard rules to rule-set.")
+	}
 	var outputPath string
 	if flagRuleSetDecompileOutput == flagRuleSetDecompileDefaultOutput {
 		if strings.HasSuffix(sourcePath, ".srs") {
@@ -75,3 +83,19 @@ func decompileRuleSet(sourcePath string) error {
 	outputFile.Close()
 	return nil
 }
+
+// hasRule reports whether cond returns true for the options of any default
+// rule in rules. Logical rules are searched recursively through their
+// sub-rules; rules of any other type are skipped.
+func hasRule(rules []option.HeadlessRule, cond func(rule option.DefaultHeadlessRule) bool) bool {
+	for _, entry := range rules {
+		if entry.Type == C.RuleTypeDefault {
+			if cond(entry.DefaultOptions) {
+				return true
+			}
+			continue
+		}
+		if entry.Type == C.RuleTypeLogical && hasRule(entry.LogicalOptions.Rules, cond) {
+			return true
+		}
+	}
+	return false
+}

+ 132 - 46
common/convertor/adguard/convertor.go

@@ -2,6 +2,7 @@ package adguard
 
 import (
 	"bufio"
+	"bytes"
 	"io"
 	"net/netip"
 	"os"
@@ -27,7 +28,7 @@ type agdguardRuleLine struct {
 	isImportant bool
 }
 
-func Convert(reader io.Reader, logger logger.Logger) ([]option.HeadlessRule, error) {
+func ToOptions(reader io.Reader, logger logger.Logger) ([]option.HeadlessRule, error) {
 	scanner := bufio.NewScanner(reader)
 	var (
 		ruleLines    []agdguardRuleLine
@@ -36,45 +37,12 @@ func Convert(reader io.Reader, logger logger.Logger) ([]option.HeadlessRule, err
 parseLine:
 	for scanner.Scan() {
 		ruleLine := scanner.Text()
-
-		// Empty line
 		if ruleLine == "" {
 			continue
 		}
-		// Comment (both line comment and in-line comment)
-		if strings.Contains(ruleLine, "!") {
-			continue
-		}
-		// Either comment or cosmetic filter
-		if strings.Contains(ruleLine, "#") {
-			ignoredLines++
-			logger.Debug("ignored unsupported cosmetic filter: ", ruleLine)
-			continue
-		}
-		// We don't support URL query anyway
-		if strings.Contains(ruleLine, "?") || strings.Contains(ruleLine, "&") {
-			ignoredLines++
-			logger.Debug("ignored unsupported rule with query: ", ruleLine)
-			continue
-		}
-		// Commonly seen in CSS selectors of cosmetic filters
-		if strings.Contains(ruleLine, "[") || strings.Contains(ruleLine, "]") {
-			ignoredLines++
-			logger.Debug("ignored unsupported cosmetic filter: ", ruleLine)
-			continue
-		}
-		if strings.Contains(ruleLine, "(") || strings.Contains(ruleLine, ")") {
-			ignoredLines++
-			logger.Debug("ignored unsupported cosmetic filter: ", ruleLine)
+		if strings.HasPrefix(ruleLine, "!") || strings.HasPrefix(ruleLine, "#") {
 			continue
 		}
-		// We don't support $domain modifier
-		if strings.Contains(ruleLine, "~") {
-			ignoredLines++
-			logger.Debug("ignored unsupported rule modifier: ", ruleLine)
-			continue
-		}
-
 		originRuleLine := ruleLine
 		if M.IsDomainName(ruleLine) {
 			ruleLines = append(ruleLines, agdguardRuleLine{
@@ -128,7 +96,7 @@ parseLine:
 				}
 				if !ignored {
 					ignoredLines++
-					logger.Debug("ignored unsupported rule with modifier: ", paramParts[0], ": ", ruleLine)
+					logger.Debug("ignored unsupported rule with modifier: ", paramParts[0], ": ", originRuleLine)
 					continue parseLine
 				}
 			}
@@ -156,17 +124,35 @@ parseLine:
 			ruleLine = ruleLine[1 : len(ruleLine)-1]
 			if ignoreIPCIDRRegexp(ruleLine) {
 				ignoredLines++
-				logger.Debug("ignored unsupported rule with IPCIDR regexp: ", ruleLine)
+				logger.Debug("ignored unsupported rule with IPCIDR regexp: ", originRuleLine)
 				continue
 			}
 			isRegexp = true
 		} else {
 			if strings.Contains(ruleLine, "://") {
 				ruleLine = common.SubstringAfter(ruleLine, "://")
+				isSuffix = true
 			}
 			if strings.Contains(ruleLine, "/") {
 				ignoredLines++
-				logger.Debug("ignored unsupported rule with path: ", ruleLine)
+				logger.Debug("ignored unsupported rule with path: ", originRuleLine)
+				continue
+			}
+			if strings.Contains(ruleLine, "?") || strings.Contains(ruleLine, "&") {
+				ignoredLines++
+				logger.Debug("ignored unsupported rule with query: ", originRuleLine)
+				continue
+			}
+			if strings.Contains(ruleLine, "[") || strings.Contains(ruleLine, "]") ||
+				strings.Contains(ruleLine, "(") || strings.Contains(ruleLine, ")") ||
+				strings.Contains(ruleLine, "!") || strings.Contains(ruleLine, "#") {
+				ignoredLines++
+				logger.Debug("ignored unsupported cosmetic filter: ", originRuleLine)
+				continue
+			}
+			if strings.Contains(ruleLine, "~") {
+				ignoredLines++
+				logger.Debug("ignored unsupported rule modifier: ", originRuleLine)
 				continue
 			}
 			var domainCheck string
@@ -185,13 +171,13 @@ parseLine:
 					_, ipErr := parseADGuardIPCIDRLine(ruleLine)
 					if ipErr == nil {
 						ignoredLines++
-						logger.Debug("ignored unsupported rule with IPCIDR: ", ruleLine)
+						logger.Debug("ignored unsupported rule with IPCIDR: ", originRuleLine)
 						continue
 					}
 					if M.ParseSocksaddr(domainCheck).Port != 0 {
-						logger.Debug("ignored unsupported rule with port: ", ruleLine)
+						logger.Debug("ignored unsupported rule with port: ", originRuleLine)
 					} else {
-						logger.Debug("ignored unsupported rule with invalid domain: ", ruleLine)
+						logger.Debug("ignored unsupported rule with invalid domain: ", originRuleLine)
 					}
 					ignoredLines++
 					continue
@@ -309,10 +295,112 @@ parseLine:
 			},
 		}
 	}
-	logger.Info("parsed rules: ", len(ruleLines), "/", len(ruleLines)+ignoredLines)
+	if ignoredLines > 0 {
+		logger.Info("parsed rules: ", len(ruleLines), "/", len(ruleLines)+ignoredLines)
+	}
 	return []option.HeadlessRule{currentRule}, nil
 }
 
+// ErrInvalid is returned by FromOptions when the supplied rules do not have
+// the nested shape that FromOptions is able to turn back into AdGuard text.
+var ErrInvalid = E.New("invalid binary AdGuard rule-set")
+
+// FromOptions renders rules back into AdGuard filter text, one rule per line.
+//
+// rules must contain exactly one rule. That rule is walked as a chain: every
+// logical rule must have exactly two children, the first being a default rule
+// that carries one group of patterns and the second being the rest of the
+// chain; the chain ends at a default rule carrying the plain block patterns.
+// Each group must have at least one AdGuardDomain or DomainRegex entry.
+//
+// Groups are recognised as follows:
+//   - OR with a non-inverted first child: "$important" block rules.
+//   - AND with an inverted first child, seen before any important group and
+//     before any other inverted group: "$important" exception rules.
+//   - AND with an inverted first child otherwise: plain exception rules.
+//
+// The output lists important block rules, important exceptions, plain block
+// rules and plain exceptions, in that order; within each group AdGuardDomain
+// entries are written verbatim and DomainRegex entries are wrapped in slashes.
+//
+// ErrInvalid is returned for any input that does not match this shape,
+// including rules whose type is neither logical nor default.
+func FromOptions(rules []option.HeadlessRule) ([]byte, error) {
+	if len(rules) != 1 {
+		return nil, ErrInvalid
+	}
+	rule := rules[0]
+	var (
+		importantDomain             []string
+		importantDomainRegex        []string
+		importantExcludeDomain      []string
+		importantExcludeDomainRegex []string
+		domain                      []string
+		domainRegex                 []string
+		excludeDomain               []string
+		excludeDomainRegex          []string
+	)
+parse:
+	for {
+		switch rule.Type {
+		case C.RuleTypeLogical:
+			children := rule.LogicalOptions.Rules
+			if len(children) != 2 || children[0].Type != C.RuleTypeDefault {
+				return nil, ErrInvalid
+			}
+			group := &children[0].DefaultOptions
+			if len(group.AdGuardDomain)+len(group.DomainRegex) == 0 {
+				return nil, ErrInvalid
+			}
+			switch {
+			case rule.LogicalOptions.Mode == C.LogicalTypeAnd && group.Invert:
+				seenImportant := len(importantDomain)+len(importantDomainRegex) > 0
+				seenImportantExclude := len(importantExcludeDomain)+len(importantExcludeDomainRegex) > 0
+				// An inverted group nested inside the important (OR) group
+				// cannot veto the important patterns, so it can only be a
+				// plain exception. Treating it as "$important" would change
+				// the meaning of the dumped rules.
+				// NOTE(review): assumes the producer nests important
+				// exceptions outside the important group - confirm against
+				// ToOptions.
+				if !seenImportant && !seenImportantExclude {
+					importantExcludeDomain = group.AdGuardDomain
+					importantExcludeDomainRegex = group.DomainRegex
+				} else {
+					excludeDomain = group.AdGuardDomain
+					excludeDomainRegex = group.DomainRegex
+				}
+			case rule.LogicalOptions.Mode == C.LogicalTypeOr && !group.Invert:
+				importantDomain = group.AdGuardDomain
+				importantDomainRegex = group.DomainRegex
+			default:
+				return nil, ErrInvalid
+			}
+			rule = children[1]
+		case C.RuleTypeDefault:
+			domain = rule.DefaultOptions.AdGuardDomain
+			domainRegex = rule.DefaultOptions.DomainRegex
+			if len(domain)+len(domainRegex) == 0 {
+				return nil, ErrInvalid
+			}
+			break parse
+		default:
+			// Without this arm an unknown rule type would never advance
+			// the chain and the loop would spin forever.
+			return nil, ErrInvalid
+		}
+	}
+	var output bytes.Buffer
+	// writeGroup emits one line per pattern: prefix + pattern + suffix for
+	// plain patterns, with regular expressions additionally wrapped in "/".
+	writeGroup := func(prefix string, suffix string, patterns []string, regexes []string) {
+		for _, ruleLine := range patterns {
+			output.WriteString(prefix)
+			output.WriteString(ruleLine)
+			output.WriteString(suffix)
+			output.WriteString("\n")
+		}
+		for _, ruleLine := range regexes {
+			output.WriteString(prefix)
+			output.WriteString("/")
+			output.WriteString(ruleLine)
+			output.WriteString("/")
+			output.WriteString(suffix)
+			output.WriteString("\n")
+		}
+	}
+	writeGroup("", "$important", importantDomain, importantDomainRegex)
+	writeGroup("@@", "$important", importantExcludeDomain, importantExcludeDomainRegex)
+	writeGroup("", "", domain, domainRegex)
+	writeGroup("@@", "", excludeDomain, excludeDomainRegex)
+	return output.Bytes(), nil
+}
+
 func ignoreIPCIDRRegexp(ruleLine string) bool {
 	if strings.HasPrefix(ruleLine, "(http?:\\/\\/)") {
 		ruleLine = ruleLine[12:]
@@ -320,11 +408,9 @@ func ignoreIPCIDRRegexp(ruleLine string) bool {
 		ruleLine = ruleLine[13:]
 	} else if strings.HasPrefix(ruleLine, "^") {
 		ruleLine = ruleLine[1:]
-	} else {
-		return false
 	}
-	_, parseErr := strconv.ParseUint(common.SubstringBefore(ruleLine, "\\."), 10, 8)
-	return parseErr == nil
+	return common.Error(strconv.ParseUint(common.SubstringBefore(ruleLine, "\\."), 10, 8)) == nil ||
+		common.Error(strconv.ParseUint(common.SubstringBefore(ruleLine, "."), 10, 8)) == nil
 }
 
 func parseAdGuardHostLine(ruleLine string) (string, error) {

+ 10 - 7
common/convertor/adguard/convertor_test.go

@@ -14,7 +14,8 @@ import (
 
 func TestConverter(t *testing.T) {
 	t.Parallel()
-	rules, err := Convert(strings.NewReader(`
+	ruleString := `||sagernet.org^$important
+@@|sing-box.sagernet.org^$important
 ||example.org^
 |example.com^
 example.net^
@@ -22,10 +23,9 @@ example.net^
 ||example.edu.tw^
 |example.gov
 example.arpa
-@@|sagernet.example.org|
-||sagernet.org^$important
-@@|sing-box.sagernet.org^$important
-`), logger.NOP())
+@@|sagernet.example.org^
+`
+	rules, err := ToOptions(strings.NewReader(ruleString), logger.NOP())
 	require.NoError(t, err)
 	require.Len(t, rules, 1)
 	rule, err := rule.NewHeadlessRule(context.Background(), rules[0])
@@ -76,11 +76,14 @@ example.arpa
 			Domain: domain,
 		}), domain)
 	}
+	ruleFromOptions, err := FromOptions(rules)
+	require.NoError(t, err)
+	require.Equal(t, ruleString, string(ruleFromOptions))
 }
 
 func TestHosts(t *testing.T) {
 	t.Parallel()
-	rules, err := Convert(strings.NewReader(`
+	rules, err := ToOptions(strings.NewReader(`
 127.0.0.1 localhost
 ::1 localhost #[IPv6]
 0.0.0.0 google.com
@@ -111,7 +114,7 @@ func TestHosts(t *testing.T) {
 
 func TestSimpleHosts(t *testing.T) {
 	t.Parallel()
-	rules, err := Convert(strings.NewReader(`
+	rules, err := ToOptions(strings.NewReader(`
 example.com
 www.example.org
 `), logger.NOP())

+ 3 - 4
common/srs/binary.go

@@ -215,16 +215,15 @@ func readDefaultRule(reader varbin.Reader, recover bool) (rule option.DefaultHea
 		case ruleItemWIFIBSSID:
 			rule.WIFIBSSID, err = readRuleItemString(reader)
 		case ruleItemAdGuardDomain:
-			if recover {
-				err = E.New("unable to decompile binary AdGuard rules to rule-set")
-				return
-			}
 			var matcher *domain.AdGuardMatcher
 			matcher, err = domain.ReadAdGuardMatcher(reader)
 			if err != nil {
 				return
 			}
 			rule.AdGuardDomainMatcher = matcher
+			if recover {
+				rule.AdGuardDomain = matcher.Dump()
+			}
 		case ruleItemNetworkType:
 			rule.NetworkType, err = readRuleItemUint8[option.InterfaceType](reader)
 		case ruleItemNetworkIsExpensive: