Ver Fonte

Refactor: *net.UDPAddr -> *net.Destination

https://t.me/projectXray/111998
RPRX há 4 anos atrás
pai
commit
13ad3fddf6

+ 3 - 2
common/buf/buffer.go

@@ -2,9 +2,9 @@ package buf
 
 import (
 	"io"
-	"net"
 
 	"github.com/xtls/xray-core/common/bytespool"
+	"github.com/xtls/xray-core/common/net"
 )
 
 const (
@@ -21,7 +21,7 @@ type Buffer struct {
 	v     []byte
 	start int32
 	end   int32
-	UDP   *net.UDPAddr
+	UDP   *net.Destination
 }
 
 // New creates a Buffer with 0 length and 2K capacity.
@@ -49,6 +49,7 @@ func (b *Buffer) Release() {
 	b.v = nil
 	b.Clear()
 	pool.Put(p)
+	b.UDP = nil
 }
 
 // Clear clears the content of the buffer, results an empty buffer with

+ 23 - 4
proxy/freedom/freedom.go

@@ -149,7 +149,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		if destination.Network == net.Network_TCP {
 			writer = buf.NewWriter(conn)
 		} else {
-			writer = NewPacketWriter(conn)
+			writer = NewPacketWriter(conn, h, ctx)
 		}
 
 		if err := buf.Copy(input, writer, buf.UpdateActivity(timer)); err != nil {
@@ -215,14 +215,18 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		return nil, err
 	}
 	b.Resize(0, int32(n))
-	b.UDP = d.(*net.UDPAddr)
+	b.UDP = &net.Destination{
+		Address: net.IPAddress(d.(*net.UDPAddr).IP),
+		Port:    net.Port(d.(*net.UDPAddr).Port),
+		Network: net.Network_UDP,
+	}
 	if r.Counter != nil {
 		r.Counter.Add(int64(n))
 	}
 	return buf.MultiBuffer{b}, nil
 }
 
-func NewPacketWriter(conn net.Conn) buf.Writer {
+func NewPacketWriter(conn net.Conn, h *Handler, ctx context.Context) buf.Writer {
 	iConn := conn
 	statConn, ok := iConn.(*internet.StatCouterConnection)
 	if ok {
@@ -236,6 +240,8 @@ func NewPacketWriter(conn net.Conn) buf.Writer {
 		return &PacketWriter{
 			PacketConnWrapper: c,
 			Counter:           counter,
+			Handler:           h,
+			Context:           ctx,
 		}
 	}
 	return &buf.SequentialWriter{Writer: conn}
@@ -244,6 +250,8 @@ func NewPacketWriter(conn net.Conn) buf.Writer {
 type PacketWriter struct {
 	*internet.PacketConnWrapper
 	stats.Counter
+	*Handler
+	context.Context
 }
 
 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
@@ -256,7 +264,18 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		var n int
 		var err error
 		if b.UDP != nil {
-			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), b.UDP)
+			if w.Handler.config.useIP() && b.UDP.Address.Family().IsDomain() {
+				ip := w.Handler.resolveIP(w.Context, b.UDP.Address.Domain(), nil)
+				if ip != nil {
+					b.UDP.Address = ip
+				}
+			}
+			destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
+			if destAddr == nil {
+				b.Release()
+				continue
+			}
+			n, err = w.PacketConnWrapper.WriteTo(b.Bytes(), destAddr)
 		} else {
 			n, err = w.PacketConnWrapper.Write(b.Bytes())
 		}

+ 7 - 12
proxy/shadowsocks/protocol.go

@@ -235,10 +235,8 @@ func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		buffer.Release()
 		return nil, err
 	}
-	payload.UDP = &net.UDPAddr{
-		IP:   u.Address.IP(),
-		Port: int(u.Port),
-	}
+	dest := u.Destination()
+	payload.UDP = &dest
 	return buf.MultiBuffer{payload}, nil
 }
 
@@ -254,18 +252,15 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		if b == nil {
 			break
 		}
-		var packet *buf.Buffer
-		var err error
+		request := w.Request
 		if b.UDP != nil {
-			request := &protocol.RequestHeader{
+			request = &protocol.RequestHeader{
 				User:    w.Request.User,
-				Address: net.IPAddress(b.UDP.IP),
-				Port:    net.Port(b.UDP.Port),
+				Address: b.UDP.Address,
+				Port:    b.UDP.Port,
 			}
-			packet, err = EncodeUDPPacket(request, b.Bytes())
-		} else {
-			packet, err = EncodeUDPPacket(w.Request, b.Bytes())
 		}
+		packet, err := EncodeUDPPacket(request, b.Bytes())
 		b.Release()
 		if err != nil {
 			buf.ReleaseMulti(mb)

+ 8 - 9
proxy/shadowsocks/server.go

@@ -81,8 +81,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 		if payload.UDP != nil {
 			request = &protocol.RequestHeader{
 				User:    request.User,
-				Address: net.IPAddress(payload.UDP.IP),
-				Port:    net.Port(payload.UDP.Port),
+				Address: payload.UDP.Address,
+				Port:    payload.UDP.Port,
 			}
 		}
 
@@ -128,25 +128,24 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 				continue
 			}
 
+			destination := request.Destination()
+
 			currentPacketCtx := ctx
 			if inbound.Source.IsValid() {
 				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 					From:   inbound.Source,
-					To:     request.Destination(),
+					To:     destination,
 					Status: log.AccessAccepted,
 					Reason: "",
 					Email:  request.User.Email,
 				})
 			}
-			newError("tunnelling request to ", request.Destination()).WriteToLog(session.ExportIDToError(currentPacketCtx))
+			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx))
 
-			data.UDP = &net.UDPAddr{
-				IP:   request.Address.IP(),
-				Port: int(request.Port),
-			}
+			data.UDP = &destination
 
 			if dest.Network == 0 {
-				dest = request.Destination() // JUST FOLLOW THE FIREST PACKET
+				dest = request.Destination() // JUST FOLLOW THE FIRST PACKET
 			}
 
 			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)

+ 9 - 9
proxy/socks/server.go

@@ -202,8 +202,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 		if payload.UDP != nil {
 			request = &protocol.RequestHeader{
 				User:    request.User,
-				Address: net.IPAddress(payload.UDP.IP),
-				Port:    net.Port(payload.UDP.Port),
+				Address: payload.UDP.Address,
+				Port:    payload.UDP.Port,
 			}
 		}
 
@@ -244,24 +244,24 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 				payload.Release()
 				continue
 			}
+
+			destination := request.Destination()
+
 			currentPacketCtx := ctx
-			newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
+			newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
 			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
 				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 					From:   inbound.Source,
-					To:     request.Destination(),
+					To:     destination,
 					Status: log.AccessAccepted,
 					Reason: "",
 				})
 			}
 
-			payload.UDP = &net.UDPAddr{
-				IP:   request.Address.IP(),
-				Port: int(request.Port),
-			}
+			payload.UDP = &destination
 
 			if dest.Network == 0 {
-				dest = request.Destination() // JUST FOLLOW THE FIREST PACKET
+				dest = destination // JUST FOLLOW THE FIRST PACKET
 			}
 
 			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)

+ 7 - 12
proxy/trojan/protocol.go

@@ -134,12 +134,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		if b == nil {
 			break
 		}
-		target := w.Target
+		target := &w.Target
 		if b.UDP != nil {
-			target.Address = net.IPAddress(b.UDP.IP)
-			target.Port = net.Port(b.UDP.Port)
+			target = b.UDP
 		}
-		if _, err := w.writePacket(b.Bytes(), target); err != nil {
+		if _, err := w.writePacket(b.Bytes(), *target); err != nil {
 			buf.ReleaseMulti(mb)
 			return err
 		}
@@ -155,12 +154,11 @@ func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net
 		if b == nil {
 			break
 		}
-		source := dest
+		source := &dest
 		if b.UDP != nil {
-			source.Address = net.IPAddress(b.UDP.IP)
-			source.Port = net.Port(b.UDP.Port)
+			source = b.UDP
 		}
-		if _, err := w.writePacket(b.Bytes(), source); err != nil {
+		if _, err := w.writePacket(b.Bytes(), *source); err != nil {
 			buf.ReleaseMulti(mb)
 			return err
 		}
@@ -312,10 +310,7 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
 		}
 
 		b := buf.New()
-		b.UDP = &net.UDPAddr{
-			IP:   addr.IP(),
-			Port: int(port.Value()),
-		}
+		b.UDP = &dest
 		mb = append(mb, b)
 		n, err := b.ReadFullFrom(r, int32(length))
 		if err != nil {

+ 1 - 1
proxy/trojan/server.go

@@ -281,7 +281,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 			newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx))
 
 			if dest.Network == 0 {
-				dest = p.Target // JUST FOLLOW THE FIREST PACKET
+				dest = p.Target // JUST FOLLOW THE FIRST PACKET
 			}
 
 			for _, b := range p.Buffer {

+ 1 - 1
transport/internet/udp/dispatcher.go

@@ -66,7 +66,7 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *c
 		cancel()
 		v.RemoveRay(dest)
 	}
-	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
+	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
 	link, _ := v.dispatcher.Dispatch(ctx, dest)
 	entry := &connEntry{
 		link:   link,