Procházet zdrojové kódy

Fix UDP conn stuck on sniff

This change only avoids permanent hangs. We need to implement read deadlines for UDP conns in 1.10 for server inbounds.
世界 před 1 rokem
rodič
revize
21b1ac26b9
1 změnil soubory, kde provedl 45 přidání a 22 odebrání
  1. 45 22
      route/router.go

+ 45 - 22
route/router.go

@@ -951,34 +951,57 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
 	}*/
 
 	if metadata.InboundOptions.SniffEnabled || metadata.Destination.Addr.IsUnspecified() {
-		buffer := buf.NewPacket()
-		destination, err := conn.ReadPacket(buffer)
+		var (
+			buffer      = buf.NewPacket()
+			destination M.Socksaddr
+			done        = make(chan struct{})
+			err         error
+		)
+		go func() {
+			sniffTimeout := C.ReadPayloadTimeout
+			if metadata.InboundOptions.SniffTimeout > 0 {
+				sniffTimeout = time.Duration(metadata.InboundOptions.SniffTimeout)
+			}
+			conn.SetReadDeadline(time.Now().Add(sniffTimeout))
+			destination, err = conn.ReadPacket(buffer)
+			conn.SetReadDeadline(time.Time{})
+			close(done)
+		}()
+		select {
+		case <-done:
+		case <-ctx.Done():
+			conn.Close()
+			return ctx.Err()
+		}
 		if err != nil {
 			buffer.Release()
-			return err
-		}
-		if metadata.Destination.Addr.IsUnspecified() {
-			metadata.Destination = destination
-		}
-		if metadata.InboundOptions.SniffEnabled {
-			sniffMetadata, _ := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.DomainNameQuery, sniff.QUICClientHello, sniff.STUNMessage)
-			if sniffMetadata != nil {
-				metadata.Protocol = sniffMetadata.Protocol
-				metadata.Domain = sniffMetadata.Domain
-				if metadata.InboundOptions.SniffOverrideDestination && M.IsDomainName(metadata.Domain) {
-					metadata.Destination = M.Socksaddr{
-						Fqdn: metadata.Domain,
-						Port: metadata.Destination.Port,
+			if !errors.Is(err, os.ErrDeadlineExceeded) {
+				return err
+			}
+		} else {
+			if metadata.Destination.Addr.IsUnspecified() {
+				metadata.Destination = destination
+			}
+			if metadata.InboundOptions.SniffEnabled {
+				sniffMetadata, _ := sniff.PeekPacket(ctx, buffer.Bytes(), sniff.DomainNameQuery, sniff.QUICClientHello, sniff.STUNMessage)
+				if sniffMetadata != nil {
+					metadata.Protocol = sniffMetadata.Protocol
+					metadata.Domain = sniffMetadata.Domain
+					if metadata.InboundOptions.SniffOverrideDestination && M.IsDomainName(metadata.Domain) {
+						metadata.Destination = M.Socksaddr{
+							Fqdn: metadata.Domain,
+							Port: metadata.Destination.Port,
+						}
+					}
+					if metadata.Domain != "" {
+						r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
+					} else {
+						r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
 					}
-				}
-				if metadata.Domain != "" {
-					r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol, ", domain: ", metadata.Domain)
-				} else {
-					r.logger.DebugContext(ctx, "sniffed packet protocol: ", metadata.Protocol)
 				}
 			}
+			conn = bufio.NewCachedPacketConn(conn, buffer, destination)
 		}
-		conn = bufio.NewCachedPacketConn(conn, buffer, destination)
 	}
 	if r.dnsReverseMapping != nil && metadata.Domain == "" {
 		domain, loaded := r.dnsReverseMapping.Query(metadata.Destination.Addr)