瀏覽代碼

Refactor: Use struct `Int32Range` widely (#3867)

https://github.com/XTLS/Xray-core/pull/3867#issuecomment-2412847316
风扇滑翔翼 1 年之前
父節點
當前提交
82bd5f3046
共有 2 個文件被更改,包括 75 次插入97 次删除
  1. 58 25
      infra/conf/common.go
  2. 17 72
      infra/conf/freedom.go

+ 58 - 25
infra/conf/common.go

@@ -246,43 +246,76 @@ func (v *User) Build() *protocol.User {
 
 
 // Int32Range deserializes from "1-2" or 1, so can deserialize from both int and number.
 // Int32Range deserializes from "1-2" or 1, so can deserialize from both int and number.
 // Negative integers can be passed as sentinel values, but do not parse as ranges.
 // Negative integers can be passed as sentinel values, but do not parse as ranges.
+// Value will be exchanged if From > To, use .Left and .Right to get original value if need.
 type Int32Range struct {
 type Int32Range struct {
-	From int32
-	To   int32
+	Left  int32
+	Right int32
+	From  int32
+	To    int32
 }
 }
 
 
 func (v *Int32Range) UnmarshalJSON(data []byte) error {
 func (v *Int32Range) UnmarshalJSON(data []byte) error {
+	defer v.ensureOrder()
 	var str string
 	var str string
 	var rawint int32
 	var rawint int32
 	if err := json.Unmarshal(data, &str); err == nil {
 	if err := json.Unmarshal(data, &str); err == nil {
-		// for number in string format like "114" or "-1"
-		if value, err := strconv.Atoi(str); err == nil {
-			v.From = int32(value)
-			v.To = int32(value)
+		left, right, err := ParseRangeString(str)
+		if err == nil {
+			v.Left, v.Right = int32(left), int32(right)
 			return nil
 			return nil
 		}
 		}
-		// for empty "", we treat it as 0
-		if str == "" {
-			v.From = 0
-			v.To = 0
-			return nil
-		}
-		// for range value, like "114-514"
-		pair := strings.SplitN(str, "-", 2)
-		if len(pair) == 2 {
-			from, err := strconv.Atoi(pair[0])
-			to, err2 := strconv.Atoi(pair[1])
-			if err == nil && err2 == nil {
-				v.From = int32(from)
-				v.To = int32(to)
-				return nil
-			}
-		}
 	} else if err := json.Unmarshal(data, &rawint); err == nil {
 	} else if err := json.Unmarshal(data, &rawint); err == nil {
-		v.From = rawint
-		v.To = rawint
+		v.Left = rawint
+		v.Right = rawint
 		return nil
 		return nil
 	}
 	}
 
 
 	return errors.New("Invalid integer range, expected either string of form \"1-2\" or plain integer.")
 	return errors.New("Invalid integer range, expected either string of form \"1-2\" or plain integer.")
 }
 }
+
+// ensureOrder() gives value to .From & .To and make sure .From < .To
+func (r *Int32Range) ensureOrder() {
+	r.From, r.To = r.Left, r.Right
+	if r.From > r.To {
+		r.From, r.To = r.To, r.From
+	}
+}
+
+// "-114-514"   →  ["-114","514"]
+// "-1919--810" →  ["-1919","-810"]
+func splitFromSecondDash(s string) []string {
+	parts := strings.SplitN(s, "-", 3)
+	if len(parts) < 3 {
+		return []string{s}
+	}
+	return []string{parts[0] + "-" + parts[1], parts[2]}
+}
+
+// Parse rang in string. Support negative number.
+// eg: "114-514" "-114-514" "-1919--810" "114514" ""(return 0)
+func ParseRangeString(str string) (int, int, error) {
+	// for number in string format like "114" or "-1"
+	if value, err := strconv.Atoi(str); err == nil {
+		return value, value, nil
+	}
+	// for empty "", we treat it as 0
+	if str == "" {
+		return 0, 0, nil
+	}
+	// for range value, like "114-514"
+	var pair []string
+	// Process sth like "-114-514" "-1919--810"
+	if strings.HasPrefix(str, "-") {
+		pair = splitFromSecondDash(str)
+	} else {
+		pair = strings.SplitN(str, "-", 2)
+	}
+	if len(pair) == 2 {
+		left, err := strconv.Atoi(pair[0])
+		right, err2 := strconv.Atoi(pair[1])
+		if err == nil && err2 == nil {
+			return left, right, nil
+		}
+	}
+	return 0, 0, errors.New("invalid range string: ", str)
+}

+ 17 - 72
infra/conf/freedom.go

@@ -3,7 +3,6 @@ package conf
 import (
 import (
 	"encoding/base64"
 	"encoding/base64"
 	"net"
 	"net"
-	"strconv"
 	"strings"
 	"strings"
 
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
@@ -24,9 +23,9 @@ type FreedomConfig struct {
 }
 }
 
 
 type Fragment struct {
 type Fragment struct {
-	Packets  string `json:"packets"`
-	Length   string `json:"length"`
-	Interval string `json:"interval"`
+	Packets  string      `json:"packets"`
+	Length   *Int32Range `json:"length"`
+	Interval *Int32Range `json:"interval"`
 }
 }
 
 
 type Noise struct {
 type Noise struct {
@@ -67,7 +66,6 @@ func (c *FreedomConfig) Build() (proto.Message, error) {
 
 
 	if c.Fragment != nil {
 	if c.Fragment != nil {
 		config.Fragment = new(freedom.Fragment)
 		config.Fragment = new(freedom.Fragment)
-		var err, err2 error
 
 
 		switch strings.ToLower(c.Fragment.Packets) {
 		switch strings.ToLower(c.Fragment.Packets) {
 		case "tlshello":
 		case "tlshello":
@@ -80,75 +78,34 @@ func (c *FreedomConfig) Build() (proto.Message, error) {
 			config.Fragment.PacketsTo = 0
 			config.Fragment.PacketsTo = 0
 		default:
 		default:
 			// TCP Segmentation (range)
 			// TCP Segmentation (range)
-			packetsFromTo := strings.Split(c.Fragment.Packets, "-")
-			if len(packetsFromTo) == 2 {
-				config.Fragment.PacketsFrom, err = strconv.ParseUint(packetsFromTo[0], 10, 64)
-				config.Fragment.PacketsTo, err2 = strconv.ParseUint(packetsFromTo[1], 10, 64)
-			} else {
-				config.Fragment.PacketsFrom, err = strconv.ParseUint(packetsFromTo[0], 10, 64)
-				config.Fragment.PacketsTo = config.Fragment.PacketsFrom
-			}
+			from, to, err := ParseRangeString(c.Fragment.Packets)
 			if err != nil {
 			if err != nil {
 				return nil, errors.New("Invalid PacketsFrom").Base(err)
 				return nil, errors.New("Invalid PacketsFrom").Base(err)
 			}
 			}
-			if err2 != nil {
-				return nil, errors.New("Invalid PacketsTo").Base(err2)
-			}
-			if config.Fragment.PacketsFrom > config.Fragment.PacketsTo {
-				config.Fragment.PacketsFrom, config.Fragment.PacketsTo = config.Fragment.PacketsTo, config.Fragment.PacketsFrom
-			}
+			config.Fragment.PacketsFrom = uint64(from)
+			config.Fragment.PacketsTo = uint64(to)
 			if config.Fragment.PacketsFrom == 0 {
 			if config.Fragment.PacketsFrom == 0 {
 				return nil, errors.New("PacketsFrom can't be 0")
 				return nil, errors.New("PacketsFrom can't be 0")
 			}
 			}
 		}
 		}
 
 
 		{
 		{
-			if c.Fragment.Length == "" {
+			if c.Fragment.Length == nil {
 				return nil, errors.New("Length can't be empty")
 				return nil, errors.New("Length can't be empty")
 			}
 			}
-			lengthMinMax := strings.Split(c.Fragment.Length, "-")
-			if len(lengthMinMax) == 2 {
-				config.Fragment.LengthMin, err = strconv.ParseUint(lengthMinMax[0], 10, 64)
-				config.Fragment.LengthMax, err2 = strconv.ParseUint(lengthMinMax[1], 10, 64)
-			} else {
-				config.Fragment.LengthMin, err = strconv.ParseUint(lengthMinMax[0], 10, 64)
-				config.Fragment.LengthMax = config.Fragment.LengthMin
-			}
-			if err != nil {
-				return nil, errors.New("Invalid LengthMin").Base(err)
-			}
-			if err2 != nil {
-				return nil, errors.New("Invalid LengthMax").Base(err2)
-			}
-			if config.Fragment.LengthMin > config.Fragment.LengthMax {
-				config.Fragment.LengthMin, config.Fragment.LengthMax = config.Fragment.LengthMax, config.Fragment.LengthMin
-			}
+			config.Fragment.LengthMin = uint64(c.Fragment.Length.From)
+			config.Fragment.LengthMax = uint64(c.Fragment.Length.To)
 			if config.Fragment.LengthMin == 0 {
 			if config.Fragment.LengthMin == 0 {
 				return nil, errors.New("LengthMin can't be 0")
 				return nil, errors.New("LengthMin can't be 0")
 			}
 			}
 		}
 		}
 
 
 		{
 		{
-			if c.Fragment.Interval == "" {
+			if c.Fragment.Interval == nil {
 				return nil, errors.New("Interval can't be empty")
 				return nil, errors.New("Interval can't be empty")
 			}
 			}
-			intervalMinMax := strings.Split(c.Fragment.Interval, "-")
-			if len(intervalMinMax) == 2 {
-				config.Fragment.IntervalMin, err = strconv.ParseUint(intervalMinMax[0], 10, 64)
-				config.Fragment.IntervalMax, err2 = strconv.ParseUint(intervalMinMax[1], 10, 64)
-			} else {
-				config.Fragment.IntervalMin, err = strconv.ParseUint(intervalMinMax[0], 10, 64)
-				config.Fragment.IntervalMax = config.Fragment.IntervalMin
-			}
-			if err != nil {
-				return nil, errors.New("Invalid IntervalMin").Base(err)
-			}
-			if err2 != nil {
-				return nil, errors.New("Invalid IntervalMax").Base(err2)
-			}
-			if config.Fragment.IntervalMin > config.Fragment.IntervalMax {
-				config.Fragment.IntervalMin, config.Fragment.IntervalMax = config.Fragment.IntervalMax, config.Fragment.IntervalMin
-			}
+			config.Fragment.IntervalMin = uint64(c.Fragment.Interval.From)
+			config.Fragment.IntervalMax = uint64(c.Fragment.Interval.To)
 		}
 		}
 	}
 	}
 
 
@@ -193,29 +150,17 @@ func (c *FreedomConfig) Build() (proto.Message, error) {
 }
 }
 
 
 func ParseNoise(noise *Noise) (*freedom.Noise, error) {
 func ParseNoise(noise *Noise) (*freedom.Noise, error) {
-	var err, err2 error
+	var err error
 	NConfig := new(freedom.Noise)
 	NConfig := new(freedom.Noise)
 
 
 	switch strings.ToLower(noise.Type) {
 	switch strings.ToLower(noise.Type) {
 	case "rand":
 	case "rand":
-		randValue := strings.Split(noise.Packet, "-")
-		if len(randValue) > 2 {
-			return nil, errors.New("Only 2 values are allowed for rand")
-		}
-		if len(randValue) == 2 {
-			NConfig.LengthMin, err = strconv.ParseUint(randValue[0], 10, 64)
-			NConfig.LengthMax, err2 = strconv.ParseUint(randValue[1], 10, 64)
-		}
-		if len(randValue) == 1 {
-			NConfig.LengthMin, err = strconv.ParseUint(randValue[0], 10, 64)
-			NConfig.LengthMax = NConfig.LengthMin
-		}
+		min, max, err := ParseRangeString(noise.Packet)
 		if err != nil {
 		if err != nil {
-			return nil, errors.New("invalid value for rand LengthMin").Base(err)
-		}
-		if err2 != nil {
-			return nil, errors.New("invalid value for rand LengthMax").Base(err2)
+			return nil, errors.New("invalid value for rand Length").Base(err)
 		}
 		}
+		NConfig.LengthMin = uint64(min)
+		NConfig.LengthMax = uint64(max)
 		if NConfig.LengthMin > NConfig.LengthMax {
 		if NConfig.LengthMin > NConfig.LengthMax {
 			NConfig.LengthMin, NConfig.LengthMax = NConfig.LengthMax, NConfig.LengthMin
 			NConfig.LengthMin, NConfig.LengthMax = NConfig.LengthMax, NConfig.LengthMin
 		}
 		}