瀏覽代碼

Parse port in dns address

世界 4 年之前
父節點
當前提交
486f96838d
共有 1 個文件被更改,包括 30 次插入18 次删除
  1. 30 18
      infra/conf/dns.go

+ 30 - 18
infra/conf/dns.go

@@ -3,6 +3,7 @@ package conf
 import (
 	"encoding/json"
 	"sort"
+	"strconv"
 	"strings"
 
 	"github.com/xtls/xray-core/app/dns"
@@ -19,28 +20,39 @@ type NameServerConfig struct {
 	ExpectIPs    StringList
 }
 
-func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
+func (c *NameServerConfig) UnmarshalJSON(data []byte) (err error) {
 	var address Address
-	if err := json.Unmarshal(data, &address); err == nil {
+	if err = json.Unmarshal(data, &address); err == nil {
 		c.Address = &address
-		return nil
+	} else {
+		var advanced struct {
+			Address      *Address   `json:"address"`
+			ClientIP     *Address   `json:"clientIp"`
+			Port         uint16     `json:"port"`
+			SkipFallback bool       `json:"skipFallback"`
+			Domains      []string   `json:"domains"`
+			ExpectIPs    StringList `json:"expectIps"`
+		}
+		if err = json.Unmarshal(data, &advanced); err == nil {
+			c.Address = advanced.Address
+			c.ClientIP = advanced.ClientIP
+			c.Port = advanced.Port
+			c.SkipFallback = advanced.SkipFallback
+			c.Domains = advanced.Domains
+			c.ExpectIPs = advanced.ExpectIPs
+		}
 	}
 
-	var advanced struct {
-		Address      *Address   `json:"address"`
-		ClientIP     *Address   `json:"clientIp"`
-		Port         uint16     `json:"port"`
-		SkipFallback bool       `json:"skipFallback"`
-		Domains      []string   `json:"domains"`
-		ExpectIPs    StringList `json:"expectIps"`
-	}
-	if err := json.Unmarshal(data, &advanced); err == nil {
-		c.Address = advanced.Address
-		c.ClientIP = advanced.ClientIP
-		c.Port = advanced.Port
-		c.SkipFallback = advanced.SkipFallback
-		c.Domains = advanced.Domains
-		c.ExpectIPs = advanced.ExpectIPs
+	if err == nil {
+		if c.Port == 0 && c.Address.Family().IsDomain() {
+			if host, port, err := net.SplitHostPort(c.Address.Domain()); err == nil {
+				port, err := strconv.Atoi(port)
+				if err == nil {
+					c.Address = &Address{Address: net.ParseAddress(host)}
+					c.Port = uint16(port)
+				}
+			}
+		}
 		return nil
 	}