Pārlūkot izejas kodu

Tests: Improve geosite & geoip tests (#5502)

https://github.com/XTLS/Xray-core/pull/5488#issuecomment-3711843548
Hossin Asaadi 2 mēneši atpakaļ
vecāks
revīzija
36425d2a6e

+ 26 - 55
app/router/condition_geoip_test.go

@@ -1,40 +1,17 @@
 package router_test
 
 import (
-	"fmt"
 	"os"
 	"path/filepath"
+	"runtime"
 	"testing"
 
 	"github.com/xtls/xray-core/app/router"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/common/platform"
-	"github.com/xtls/xray-core/common/platform/filesystem"
-	"google.golang.org/protobuf/proto"
+	"github.com/xtls/xray-core/infra/conf"
 )
 
-func getAssetPath(file string) (string, error) {
-	path := platform.GetAssetLocation(file)
-	_, err := os.Stat(path)
-	if os.IsNotExist(err) {
-		path := filepath.Join("..", "..", "resources", file)
-		_, err := os.Stat(path)
-		if os.IsNotExist(err) {
-			return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
-		}
-		if err != nil {
-			return "", fmt.Errorf("can't stat %s: %v", path, err)
-		}
-		return path, nil
-	}
-	if err != nil {
-		return "", fmt.Errorf("can't stat %s: %v", path, err)
-	}
-
-	return path, nil
-}
-
 func TestGeoIPMatcher(t *testing.T) {
 	cidrList := []*router.CIDR{
 		{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -182,12 +159,11 @@ func TestGeoIPReverseMatcher(t *testing.T) {
 }
 
 func TestGeoIPMatcher4CN(t *testing.T) {
-	ips, err := loadGeoIP("CN")
+	geo := "geoip:cn"
+	geoip, err := loadGeoIP(geo)
 	common.Must(err)
 
-	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
-		Cidr: ips,
-	})
+	matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
 	common.Must(err)
 
 	if matcher.Match([]byte{8, 8, 8, 8}) {
@@ -196,12 +172,11 @@ func TestGeoIPMatcher4CN(t *testing.T) {
 }
 
 func TestGeoIPMatcher6US(t *testing.T) {
-	ips, err := loadGeoIP("US")
+	geo := "geoip:us"
+	geoip, err := loadGeoIP(geo)
 	common.Must(err)
 
-	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
-		Cidr: ips,
-	})
+	matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
 	common.Must(err)
 
 	if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
@@ -209,37 +184,34 @@ func TestGeoIPMatcher6US(t *testing.T) {
 	}
 }
 
-func loadGeoIP(country string) ([]*router.CIDR, error) {
-	path, err := getAssetPath("geoip.dat")
-	if err != nil {
-		return nil, err
-	}
-	geoipBytes, err := filesystem.ReadFile(path)
+func loadGeoIP(geo string) (*router.GeoIP, error) {
+	os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
+
+	geoip, err := conf.ToCidrList([]string{geo})
 	if err != nil {
 		return nil, err
 	}
 
-	var geoipList router.GeoIPList
-	if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
-		return nil, err
+	if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+		geoip, err = router.GetGeoIPList(geoip)
+		if err != nil {
+			return nil, err
+		}
 	}
 
-	for _, geoip := range geoipList.Entry {
-		if geoip.CountryCode == country {
-			return geoip.Cidr, nil
-		}
+	if len(geoip) == 0 {
+		panic("country not found: " + geo)
 	}
 
-	panic("country not found: " + country)
+	return geoip[0], nil
 }
 
 func BenchmarkGeoIPMatcher4CN(b *testing.B) {
-	ips, err := loadGeoIP("CN")
+	geo := "geoip:cn"
+	geoip, err := loadGeoIP(geo)
 	common.Must(err)
 
-	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
-		Cidr: ips,
-	})
+	matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
 	common.Must(err)
 
 	b.ResetTimer()
@@ -250,12 +222,11 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
 }
 
 func BenchmarkGeoIPMatcher6US(b *testing.B) {
-	ips, err := loadGeoIP("US")
+	geo := "geoip:us"
+	geoip, err := loadGeoIP(geo)
 	common.Must(err)
 
-	matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
-		Cidr: ips,
-	})
+	matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
 	common.Must(err)
 
 	b.ResetTimer()

+ 65 - 28
app/router/condition_test.go

@@ -1,20 +1,22 @@
 package router_test
 
 import (
+	"os"
+	"path/filepath"
+	"runtime"
 	"strconv"
 	"testing"
 
+	"github.com/xtls/xray-core/app/router"
 	. "github.com/xtls/xray-core/app/router"
 	"github.com/xtls/xray-core/common"
-	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/common/platform/filesystem"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/protocol/http"
 	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/features/routing"
 	routing_session "github.com/xtls/xray-core/features/routing/session"
-	"google.golang.org/protobuf/proto"
+	"github.com/xtls/xray-core/infra/conf"
 )
 
 func withBackground() routing.Context {
@@ -300,32 +302,25 @@ func TestRoutingRule(t *testing.T) {
 	}
 }
 
-func loadGeoSite(country string) ([]*Domain, error) {
-	path, err := getAssetPath("geosite.dat")
-	if err != nil {
-		return nil, err
-	}
-	geositeBytes, err := filesystem.ReadFile(path)
-	if err != nil {
-		return nil, err
-	}
+func loadGeoSiteDomains(geo string) ([]*Domain, error) {
+	os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
 
-	var geositeList GeoSiteList
-	if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
+	domains, err := conf.ParseDomainRule(geo)
+	if err != nil {
 		return nil, err
 	}
 
-	for _, site := range geositeList.Entry {
-		if site.CountryCode == country {
-			return site.Domain, nil
+	if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
+		domains, err = router.GetDomainList(domains)
+		if err != nil {
+			return nil, err
 		}
 	}
-
-	return nil, errors.New("country not found: " + country)
+	return domains, nil
 }
 
 func TestChinaSites(t *testing.T) {
-	domains, err := loadGeoSite("CN")
+	domains, err := loadGeoSiteDomains("geosite:cn")
 	common.Must(err)
 
 	acMatcher, err := NewMphMatcherGroup(domains)
@@ -366,8 +361,50 @@ func TestChinaSites(t *testing.T) {
 	}
 }
 
+func TestChinaSitesWithAttrs(t *testing.T) {
+	domains, err := loadGeoSiteDomains("geosite:google@cn")
+	common.Must(err)
+
+	acMatcher, err := NewMphMatcherGroup(domains)
+	common.Must(err)
+
+	type TestCase struct {
+		Domain string
+		Output bool
+	}
+	testCases := []TestCase{
+		{
+			Domain: "google.cn",
+			Output: true,
+		},
+		{
+			Domain: "recaptcha.net",
+			Output: true,
+		},
+		{
+			Domain: "164.com",
+			Output: false,
+		},
+		{
+			Domain: "164.com",
+			Output: false,
+		},
+	}
+
+	for i := 0; i < 1024; i++ {
+		testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
+	}
+
+	for _, testCase := range testCases {
+		r := acMatcher.ApplyDomain(testCase.Domain)
+		if r != testCase.Output {
+			t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
+		}
+	}
+}
+
 func BenchmarkMphDomainMatcher(b *testing.B) {
-	domains, err := loadGeoSite("CN")
+	domains, err := loadGeoSiteDomains("geosite:cn")
 	common.Must(err)
 
 	matcher, err := NewMphMatcherGroup(domains)
@@ -412,11 +449,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
 	var geoips []*GeoIP
 
 	{
-		ips, err := loadGeoIP("CN")
+		ips, err := loadGeoIP("geoip:cn")
 		common.Must(err)
 		geoips = append(geoips, &GeoIP{
 			CountryCode: "CN",
-			Cidr:        ips,
+			Cidr:        ips.Cidr,
 		})
 	}
 
@@ -425,25 +462,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
 		common.Must(err)
 		geoips = append(geoips, &GeoIP{
 			CountryCode: "JP",
-			Cidr:        ips,
+			Cidr:        ips.Cidr,
 		})
 	}
 
 	{
-		ips, err := loadGeoIP("CA")
+		ips, err := loadGeoIP("geoip:ca")
 		common.Must(err)
 		geoips = append(geoips, &GeoIP{
 			CountryCode: "CA",
-			Cidr:        ips,
+			Cidr:        ips.Cidr,
 		})
 	}
 
 	{
-		ips, err := loadGeoIP("US")
+		ips, err := loadGeoIP("geoip:us")
 		common.Must(err)
 		geoips = append(geoips, &GeoIP{
 			CountryCode: "US",
-			Cidr:        ips,
+			Cidr:        ips.Cidr,
 		})
 	}
 

+ 3 - 3
app/router/config.go

@@ -112,7 +112,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		domains := rr.Domain
 		if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
 			var err error
-			domains, err = getDomainList(rr.Domain)
+			domains, err = GetDomainList(rr.Domain)
 			if err != nil {
 				return nil, errors.New("failed to build domains from mmap").Base(err)
 			}
@@ -122,7 +122,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		if err != nil {
 			return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
 		}
-		errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
+		errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
 		conds.Add(matcher)
 	}
 
@@ -214,7 +214,7 @@ func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
 
 }
 
-func getDomainList(domains []*Domain) ([]*Domain, error) {
+func GetDomainList(domains []*Domain) ([]*Domain, error) {
 	domainList := []*Domain{}
 	for _, domain := range domains {
 		val := strings.Split(domain.Value, "_")

+ 4 - 1
common/platform/windows.go

@@ -3,7 +3,9 @@
 
 package platform
 
-import "path/filepath"
+import (
+	"path/filepath"
+)
 
 func LineSeparator() string {
 	return "\r\n"
@@ -12,6 +14,7 @@ func LineSeparator() string {
 // GetAssetLocation searches for `file` in the env dir and the executable dir
 func GetAssetLocation(file string) string {
 	assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)
+
 	return filepath.Join(assetPath, file)
 }
 

+ 1 - 1
infra/conf/dns.go

@@ -89,7 +89,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
 	var originalRules []*dns.NameServer_OriginalRule
 
 	for _, rule := range c.Domains {
-		parsedDomain, err := parseDomainRule(rule)
+		parsedDomain, err := ParseDomainRule(rule)
 		if err != nil {
 			return nil, errors.New("invalid domain rule: ", rule).Base(err)
 		}

+ 3 - 3
infra/conf/router.go

@@ -291,7 +291,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
 	return filteredDomains, nil
 }
 
-func parseDomainRule(domain string) ([]*router.Domain, error) {
+func ParseDomainRule(domain string) ([]*router.Domain, error) {
 	if strings.HasPrefix(domain, "geosite:") {
 		country := strings.ToUpper(domain[8:])
 		domains, err := loadGeositeWithAttr("geosite.dat", country)
@@ -489,7 +489,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
 
 	if rawFieldRule.Domain != nil {
 		for _, domain := range *rawFieldRule.Domain {
-			rules, err := parseDomainRule(domain)
+			rules, err := ParseDomainRule(domain)
 			if err != nil {
 				return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
 			}
@@ -499,7 +499,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
 
 	if rawFieldRule.Domains != nil {
 		for _, domain := range *rawFieldRule.Domains {
-			rules, err := parseDomainRule(domain)
+			rules, err := ParseDomainRule(domain)
 			if err != nil {
 				return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
 			}