浏览代码

Freedom UDP: Fix some cone uses like STUN,... when address is domain (#4942)

https://github.com/XTLS/Xray-core/issues/2962#issuecomment-3120472154
patterniha 3 月之前
父节点
当前提交
10376f5b4d
共有 1 个文件被更改,包括 17 次插入7 次删除
  1. 17 7
      proxy/freedom/freedom.go

+ 17 - 7
proxy/freedom/freedom.go

@@ -285,14 +285,18 @@ func NewPacketReader(conn net.Conn, UDPOverride net.Destination, DialDest net.De
 		counter = statConn.ReadCounter
 	}
 	if c, ok := iConn.(*internet.PacketConnWrapper); ok {
-		isAddrChanged := false
-		if UDPOverride.Address != nil || UDPOverride.Port != 0 || DialDest.Address.Family().IsDomain() {
-			isAddrChanged = true
+		isOverridden := false
+		if UDPOverride.Address != nil || UDPOverride.Port != 0 {
+			isOverridden = true
 		}
+		changedAddress, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
+
 		return &PacketReader{
 			PacketConnWrapper: c,
 			Counter:           counter,
-			IsAddrChanged:     isAddrChanged,
+			IsOverridden:      isOverridden,
+			InitUnchangedAddr: DialDest.Address,
+			InitChangedAddr:   net.ParseAddress(changedAddress),
 		}
 	}
 	return &buf.PacketReader{Reader: conn}
@@ -301,7 +305,9 @@ func NewPacketReader(conn net.Conn, UDPOverride net.Destination, DialDest net.De
 type PacketReader struct {
 	*internet.PacketConnWrapper
 	stats.Counter
-	IsAddrChanged bool
+	IsOverridden      bool
+	InitUnchangedAddr net.Address
+	InitChangedAddr   net.Address
 }
 
 func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
@@ -315,9 +321,13 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	b.Resize(0, int32(n))
 	// if udp dest addr is changed, we are unable to get the correct src addr
 	// so we don't attach src info to udp packet, break cone behavior, assuming the dial dest is the expected scr addr
-	if !r.IsAddrChanged {
+	if !r.IsOverridden {
+		address := net.IPAddress(d.(*net.UDPAddr).IP)
+		if r.InitChangedAddr == address {
+			address = r.InitUnchangedAddr
+		}
 		b.UDP = &net.Destination{
-			Address: net.IPAddress(d.(*net.UDPAddr).IP),
+			Address: address,
 			Port:    net.Port(d.(*net.UDPAddr).Port),
 			Network: net.Network_UDP,
 		}