Browse Source

refactor: StunMessage5389

Bruce Wayne 2 years ago
parent
commit
38a9df1a88

+ 1 - 1
STUN/Client/StunClient3489.cs

@@ -170,7 +170,7 @@ public class StunClient3489 : IStunClient
 			SocketReceiveMessageFromResult r = await _proxy.ReceiveMessageFromAsync(buffer, SocketFlags.None, receive, cts.Token);
 
 			StunMessage5389 message = new();
-			if (message.TryParse(buffer.Span[..r.ReceivedBytes]) && message.IsSameTransaction(sendMessage))
+			if (message.TryParse(buffer[..r.ReceivedBytes]) && message.IsSameTransaction(sendMessage))
 			{
 				return new StunResponse(message, (IPEndPoint)r.RemoteEndPoint, r.PacketInformation.Address);
 			}

+ 1 - 1
STUN/Client/StunClient5389UDP.cs

@@ -283,7 +283,7 @@ public class StunClient5389UDP : IStunClient
 			SocketReceiveMessageFromResult r = await _proxy.ReceiveMessageFromAsync(buffer, SocketFlags.None, receive, cts.Token);
 
 			StunMessage5389 message = new();
-			if (message.TryParse(buffer.Span[..r.ReceivedBytes]) && message.IsSameTransaction(sendMessage))
+			if (message.TryParse(buffer[..r.ReceivedBytes]) && message.IsSameTransaction(sendMessage))
 			{
 				return new StunResponse(message, (IPEndPoint)r.RemoteEndPoint, r.PacketInformation.Address);
 			}

+ 5 - 1
STUN/Enums/FilteringBehavior.cs

@@ -7,5 +7,9 @@ public enum FilteringBehavior
 	EndpointIndependent,
 	AddressDependent,
 	AddressAndPortDependent,
-	Fail
+
+	/// <summary>
+	/// Filtering test applies only to UDP.
+	/// </summary>
+	None
 }

+ 2 - 2
STUN/Messages/StunAttribute.cs

@@ -58,7 +58,7 @@ public class StunAttribute
 	{
 		if (buffer.Length < 4)
 		{
-			return 0;
+			return default;
 		}
 
 		Type = (AttributeType)BinaryPrimitives.ReadUInt16BigEndian(buffer);
@@ -67,7 +67,7 @@ public class StunAttribute
 
 		if (buffer.Length < 4 + Length)
 		{
-			return 0;
+			return default;
 		}
 
 		ReadOnlySpan<byte> value = buffer.Slice(4, Length);

+ 72 - 34
STUN/Messages/StunMessage5389.cs

@@ -1,5 +1,6 @@
 using Microsoft;
 using STUN.Enums;
+using System.Buffers;
 using System.Buffers.Binary;
 using System.Diagnostics;
 using System.Security.Cryptography;
@@ -13,6 +14,12 @@ public class StunMessage5389
 {
 	#region Header
 
+	private const int SizeOfMessageType = sizeof(StunMessageType);
+	private const int SizeOfLength = sizeof(ushort);
+	private const int SizeOfMagicCookie = sizeof(uint);
+	private const int SizeOfTransactionId = 12;
+	private const int HeaderLength = SizeOfMessageType + SizeOfLength + SizeOfMagicCookie + SizeOfTransactionId;
+
 	public StunMessageType StunMessageType { get; set; }
 
 	public uint MagicCookie { get; set; }
@@ -28,23 +35,23 @@ public class StunMessage5389
 		Attributes = Array.Empty<StunAttribute>();
 		StunMessageType = StunMessageType.BindingRequest;
 		MagicCookie = 0x2112A442;
-		TransactionId = new byte[12];
+		TransactionId = new byte[SizeOfTransactionId];
 		RandomNumberGenerator.Fill(TransactionId);
 	}
 
 	public int WriteTo(Span<byte> buffer)
 	{
-		ushort messageLength = Attributes.Aggregate<StunAttribute, ushort>(0, (current, attribute) => (ushort)(current + attribute.RealLength));
-		int length = 20 + messageLength;
+		ushort messageLength = (ushort)Attributes.Sum(x => x.RealLength);
+		int length = HeaderLength + messageLength;
 		Requires.Range(buffer.Length >= length, nameof(buffer));
 
 		BinaryPrimitives.WriteUInt16BigEndian(buffer, (ushort)StunMessageType);
-		BinaryPrimitives.WriteUInt16BigEndian(buffer[2..], messageLength);
-		BinaryPrimitives.WriteUInt32BigEndian(buffer[4..], MagicCookie);
-		TransactionId.CopyTo(buffer[8..]);
+		BinaryPrimitives.WriteUInt16BigEndian(buffer[SizeOfMessageType..], messageLength);
+		BinaryPrimitives.WriteUInt32BigEndian(buffer[(SizeOfMessageType + SizeOfLength)..], MagicCookie);
+		TransactionId.CopyTo(buffer[(SizeOfMessageType + SizeOfLength + SizeOfMagicCookie)..]);
 
-		buffer = buffer[20..];
-		foreach (StunAttribute? attribute in Attributes)
+		buffer = buffer[HeaderLength..];
+		foreach (StunAttribute attribute in Attributes)
 		{
 			int outLength = attribute.WriteTo(buffer);
 			buffer = buffer[outLength..];
@@ -53,60 +60,91 @@ public class StunMessage5389
 		return length;
 	}
 
-	public bool TryParse(ReadOnlySpan<byte> buffer)
+	public bool TryParse(ReadOnlyMemory<byte> buffer)
+	{
+		ReadOnlySequence<byte> sequence = new(buffer);
+		return TryParse(ref sequence);
+	}
+
+	public bool TryParse(ref ReadOnlySequence<byte> sequence)
 	{
-		if (buffer.Length < 20)
+		if (sequence.Length < HeaderLength)
 		{
 			return false; // Check length
 		}
 
-		Span<byte> tempSpan = stackalloc byte[2];
+		SequenceReader<byte> reader = new(sequence);
+
+		if (!reader.TryReadBigEndian(out short typeValue))
+		{
+			throw Assumes.NotReachable();
+		}
 
-		tempSpan[0] = (byte)(buffer[0] & 0b0011_1111);
-		tempSpan[1] = buffer[1];
-		StunMessageType type = (StunMessageType)BinaryPrimitives.ReadUInt16BigEndian(tempSpan);
+		StunMessageType type = (StunMessageType)(ushort)(typeValue & 0b0011_1111_1111_1111);
 
-		if (!Enum.IsDefined(typeof(StunMessageType), type))
+		if (!Enum.IsDefined(type))
 		{
 			return false;
 		}
 
 		StunMessageType = type;
 
-		ushort length = BinaryPrimitives.ReadUInt16BigEndian(buffer[2..]);
-
-		MagicCookie = BinaryPrimitives.ReadUInt32BigEndian(buffer[4..]);
+		if (!reader.TryReadBigEndian(out short lengthValue))
+		{
+			throw Assumes.NotReachable();
+		}
 
-		buffer.Slice(8, 12).CopyTo(TransactionId);
+		ushort length = (ushort)lengthValue;
 
-		if (buffer.Length != length + 20)
+		if (sequence.Length - HeaderLength < length)
 		{
 			return false; // Check length
 		}
 
-		List<StunAttribute> list = new();
+		if (!reader.TryReadBigEndian(out int magicCookie))
+		{
+			throw Assumes.NotReachable();
+		}
 
-		ReadOnlySpan<byte> attributeBuffer = buffer[20..];
-		ReadOnlySpan<byte> magicCookieAndTransactionId = buffer.Slice(4, 16);
+		MagicCookie = (uint)magicCookie;
 
-		while (attributeBuffer.Length > 0)
+		reader.UnreadSequence.Slice(0, SizeOfTransactionId).CopyTo(TransactionId);
+		reader.Advance(SizeOfTransactionId);
+
+		byte[] tempBuffer = ArrayPool<byte>.Shared.Rent(length + SizeOfMagicCookie + SizeOfTransactionId);
+		try
 		{
-			StunAttribute attribute = new();
-			int offset = attribute.TryParse(attributeBuffer, magicCookieAndTransactionId);
-			if (offset > 0)
+			reader.UnreadSequence.Slice(0, length).CopyTo(tempBuffer);
+			reader.Advance(length);
+			sequence.Slice(SizeOfMessageType + SizeOfLength, SizeOfMagicCookie + SizeOfTransactionId).CopyTo(tempBuffer.AsSpan(length));
+
+			List<StunAttribute> list = new();
+
+			Span<byte> attributeBuffer = tempBuffer.AsSpan(0, length);
+			ReadOnlySpan<byte> magicCookieAndTransactionId = tempBuffer.AsSpan(length, SizeOfMagicCookie + SizeOfTransactionId);
+
+			while (attributeBuffer.Length > default(int))
 			{
+				StunAttribute attribute = new();
+				int offset = attribute.TryParse(attributeBuffer, magicCookieAndTransactionId);
+				if (offset <= default(int))
+				{
+					Debug.WriteLine($@"[Warning] Ignore wrong attribute: {Convert.ToHexString(attributeBuffer)}");
+					break;
+				}
+
 				list.Add(attribute);
 				attributeBuffer = attributeBuffer[offset..];
 			}
-			else
-			{
-				Debug.WriteLine($@"[Warning] Ignore wrong attribute: {Convert.ToHexString(attributeBuffer)}");
-				break;
-			}
-		}
 
-		Attributes = list;
+			Attributes = list;
+		}
+		finally
+		{
+			ArrayPool<byte>.Shared.Return(tempBuffer);
+		}
 
+		sequence = reader.UnreadSequence;
 		return true;
 	}