Browse Source

Chore: Refactor infra/conf.TestToCidrList() (#4017)

zonescape 11 months ago
parent
commit
ec1fd008c4
2 changed files with 34 additions and 30 deletions
  1. 0 20
      infra/conf/dns_test.go
  2. 34 10
      infra/conf/router_test.go

+ 0 - 20
infra/conf/dns_test.go

@@ -2,35 +2,15 @@ package conf_test
 
 import (
 	"encoding/json"
-	"os"
-	"path/filepath"
 	"testing"
 
 	"github.com/xtls/xray-core/app/dns"
-	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common/net"
-	"github.com/xtls/xray-core/common/platform"
-	"github.com/xtls/xray-core/common/platform/filesystem"
 	. "github.com/xtls/xray-core/infra/conf"
 	"google.golang.org/protobuf/proto"
 )
 
-func init() {
-	wd, err := os.Getwd()
-	common.Must(err)
-
-	if _, err := os.Stat(platform.GetAssetLocation("geoip.dat")); err != nil && os.IsNotExist(err) {
-		common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(wd, "..", "..", "resources", "geoip.dat")))
-	}
-
-	os.Setenv("xray.location.asset", wd)
-}
-
 func TestDNSConfigParsing(t *testing.T) {
-	defer func() {
-		os.Unsetenv("xray.location.asset")
-	}()
-
 	parserCreator := func() func(string) (proto.Message, error) {
 		return func(s string) (proto.Message, error) {
 			config := new(DNSConfig)

+ 34 - 10
infra/conf/router_test.go

@@ -2,6 +2,7 @@ package conf_test
 
 import (
 	"encoding/json"
+	"fmt"
 	"os"
 	"path/filepath"
 	"testing"
@@ -18,21 +19,44 @@ import (
 	"google.golang.org/protobuf/proto"
 )
 
-func init() {
-	wd, err := os.Getwd()
-	common.Must(err)
-
-	if _, err := os.Stat(platform.GetAssetLocation("geoip.dat")); err != nil && os.IsNotExist(err) {
-		common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(wd, "..", "..", "resources", "geoip.dat")))
+func getAssetPath(file string) (string, error) {
+	path := platform.GetAssetLocation(file)
+	_, err := os.Stat(path)
+	if os.IsNotExist(err) {
+		path := filepath.Join("..", "..", "resources", file)
+		_, err := os.Stat(path)
+		if os.IsNotExist(err) {
+			return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
+		}
+		if err != nil {
+			return "", fmt.Errorf("can't stat %s: %v", path, err)
+		}
+		return path, nil
+	}
+	if err != nil {
+		return "", fmt.Errorf("can't stat %s: %v", path, err)
 	}
 
-	os.Setenv("xray.location.asset", wd)
+	return path, nil
 }
 
 func TestToCidrList(t *testing.T) {
-	t.Log(os.Getenv("xray.location.asset"))
+	tempDir, err := os.MkdirTemp("", "test-")
+	if err != nil {
+		t.Fatalf("can't create temp dir: %v", err)
+	}
+	defer os.RemoveAll(tempDir)
+
+	geoipPath, err := getAssetPath("geoip.dat")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	common.Must(filesystem.CopyFile(filepath.Join(tempDir, "geoip.dat"), geoipPath))
+	common.Must(filesystem.CopyFile(filepath.Join(tempDir, "geoiptestrouter.dat"), geoipPath))
 
-	common.Must(filesystem.CopyFile(platform.GetAssetLocation("geoiptestrouter.dat"), "geoip.dat"))
+	os.Setenv("xray.location.asset", tempDir)
+	defer os.Unsetenv("xray.location.asset")
 
 	ips := StringList([]string{
 		"geoip:us",
@@ -44,7 +68,7 @@ func TestToCidrList(t *testing.T) {
 		"ext-ip:geoiptestrouter.dat:!ca",
 	})
 
-	_, err := ToCidrList(ips)
+	_, err = ToCidrList(ips)
 	if err != nil {
 		t.Fatalf("Failed to parse geoip list, got %s", err)
 	}