소스 검색

Trojan UoT: Fix memory/goroutine leak (#5064)

patterniha 3 달 전
부모
커밋
ea1a3ae8f1
1개의 변경된 파일52개의 추가작업 그리고 36개의 파일을 삭제
  1. 52 36
      proxy/trojan/server.go

+ 52 - 36
proxy/trojan/server.go

@@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 	sessionPolicy = s.policyManager.ForLevel(user.Level)
 
 	if destination.Network == net.Network_UDP { // handle udp request
-		return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
+		return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
 	}
 
 	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
@@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 	return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
 }
 
-func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
+func (s *Server) handleUDPPayload(ctx context.Context, sessionPolicy policy.Session, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
+	defer timer.SetTimeout(0)
 	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
 		udpPayload := packet.Payload
 		if udpPayload.UDP == nil {
@@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 
 		if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
 			errors.LogWarningInner(ctx, err, "failed to write response")
+			cancel()
+		} else {
+			timer.Update()
 		}
 	})
 	defer udpServer.RemoveRay()
@@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
 
 	var dest *net.Destination
 
-	for {
-		select {
-		case <-ctx.Done():
-			return nil
-		default:
-			mb, err := clientReader.ReadMultiBuffer()
-			if err != nil {
-				if errors.Cause(err) != io.EOF {
-					return errors.New("unexpected EOF").Base(err)
-				}
+	requestDone := func() error {
+		for {
+			select {
+			case <-ctx.Done():
 				return nil
-			}
+			default:
+				mb, err := clientReader.ReadMultiBuffer()
+				if err != nil {
+					if errors.Cause(err) != io.EOF {
+						return errors.New("unexpected EOF").Base(err)
+					}
+					return nil
+				}
 
-			mb2, b := buf.SplitFirst(mb)
-			if b == nil {
-				continue
-			}
-			destination := *b.UDP
-
-			currentPacketCtx := ctx
-			if inbound.Source.IsValid() {
-				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
-					From:   inbound.Source,
-					To:     destination,
-					Status: log.AccessAccepted,
-					Reason: "",
-					Email:  user.Email,
-				})
-			}
-			errors.LogInfo(ctx, "tunnelling request to ", destination)
+				mb2, b := buf.SplitFirst(mb)
+				if b == nil {
+					continue
+				}
+				timer.Update()
+				destination := *b.UDP
+
+				currentPacketCtx := ctx
+				if inbound.Source.IsValid() {
+					currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
+						From:   inbound.Source,
+						To:     destination,
+						Status: log.AccessAccepted,
+						Reason: "",
+						Email:  user.Email,
+					})
+				}
+				errors.LogInfo(ctx, "tunnelling request to ", destination)
 
-			if !s.cone || dest == nil {
-				dest = &destination
-			}
+				if !s.cone || dest == nil {
+					dest = &destination
+				}
 
-			udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
-			for _, payload := range mb2 {
-				udpServer.Dispatch(currentPacketCtx, *dest, payload)
+				udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
+				for _, payload := range mb2 {
+					udpServer.Dispatch(currentPacketCtx, *dest, payload)
+				}
 			}
 		}
+
+	}
+
+	if err := task.Run(ctx, requestDone); err != nil {
+		return err
 	}
+	return nil
 }
 
 func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,