Explorar el Código

Adjust Trojan & Socks handleUDPPayload

RPRX hace 4 años
padre
commit
fb0e517158
Se han modificado 2 ficheros con 22 adiciones y 12 borrados
  1. 3 2
      proxy/socks/server.go
  2. 19 10
      proxy/trojan/server.go

+ 3 - 2
proxy/socks/server.go

@@ -218,7 +218,8 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 		conn.Write(udpMessage.Bytes())
 		conn.Write(udpMessage.Bytes())
 	})
 	})
 
 
-	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
+	inbound := session.InboundFromContext(ctx)
+	if inbound != nil && inbound.Source.IsValid() {
 		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
 		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
 	}
 	}
 
 
@@ -249,7 +250,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 
 
 			currentPacketCtx := ctx
 			currentPacketCtx := ctx
 			newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
 			newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
-			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
+			if inbound != nil && inbound.Source.IsValid() {
 				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
 					From:   inbound.Source,
 					From:   inbound.Source,
 					To:     destination,
 					To:     destination,

+ 19 - 10
proxy/trojan/server.go

@@ -251,7 +251,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
 func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
 	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		udpPayload := packet.Payload
 		udpPayload := packet.Payload
-		udpPayload.UDP = &packet.Source
+		if udpPayload.UDP == nil {
+			udpPayload.UDP = &packet.Source
+		}
 		common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}))
 		common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}))
 	})
 	})
 
 
@@ -274,23 +276,30 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 			}
 			}
 
 
 			mb2, b := buf.SplitFirst(mb)
 			mb2, b := buf.SplitFirst(mb)
+			if b == nil {
+				continue
+			}
 			destination := *b.UDP
 			destination := *b.UDP
-			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
-				From:   inbound.Source,
-				To:     destination,
-				Status: log.AccessAccepted,
-				Reason: "",
-				Email:  user.Email,
-			})
+
+			currentPacketCtx := ctx
+			if inbound.Source.IsValid() {
+				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
+					From:   inbound.Source,
+					To:     destination,
+					Status: log.AccessAccepted,
+					Reason: "",
+					Email:  user.Email,
+				})
+			}
 			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
 			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
 
 
 			if !buf.Cone || dest == nil {
 			if !buf.Cone || dest == nil {
 				dest = &destination
 				dest = &destination
 			}
 			}
 
 
-			udpServer.Dispatch(ctx, *dest, b) // first packet
+			udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
 			for _, payload := range mb2 {
 			for _, payload := range mb2 {
-				udpServer.Dispatch(ctx, *dest, payload)
+				udpServer.Dispatch(currentPacketCtx, *dest, payload)
 			}
 			}
 		}
 		}
 	}
 	}