Răsfoiți Sursa

Clean up SequenceEqual.

Bart De Smet 7 ani în urmă
părinte
comite
89484fab0a

+ 30 - 26
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/SequenceEqual.cs

@@ -10,30 +10,27 @@ namespace System.Linq
 {
     public static partial class AsyncEnumerable
     {
-        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
+        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
         {
             if (first == null)
                 throw Error.ArgumentNull(nameof(first));
             if (second == null)
                 throw Error.ArgumentNull(nameof(second));
-            if (comparer == null)
-                throw Error.ArgumentNull(nameof(comparer));
 
-            return SequenceEqual(first, second, comparer, CancellationToken.None);
+            return SequenceEqualCore(first, second, EqualityComparer<TSource>.Default, CancellationToken.None);
         }
 
-        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
+        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, CancellationToken cancellationToken)
         {
             if (first == null)
                 throw Error.ArgumentNull(nameof(first));
             if (second == null)
                 throw Error.ArgumentNull(nameof(second));
 
-            return SequenceEqual(first, second, CancellationToken.None);
+            return SequenceEqualCore(first, second, EqualityComparer<TSource>.Default, cancellationToken);
         }
 
-
-        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
+        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
         {
             if (first == null)
                 throw Error.ArgumentNull(nameof(first));
@@ -42,53 +39,60 @@ namespace System.Linq
             if (comparer == null)
                 throw Error.ArgumentNull(nameof(comparer));
 
-            return SequenceEqualCore(first, second, comparer, cancellationToken);
+            return SequenceEqualCore(first, second, comparer, CancellationToken.None);
         }
 
-        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, CancellationToken cancellationToken)
+        public static Task<bool> SequenceEqual<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
         {
             if (first == null)
                 throw Error.ArgumentNull(nameof(first));
             if (second == null)
                 throw Error.ArgumentNull(nameof(second));
+            if (comparer == null)
+                throw Error.ArgumentNull(nameof(comparer));
 
-            return first.SequenceEqual(second, EqualityComparer<TSource>.Default, cancellationToken);
+            return SequenceEqualCore(first, second, comparer, cancellationToken);
         }
 
-        private static async Task<bool> SequenceEqualCore<TSource>(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
+        private static Task<bool> SequenceEqualCore<TSource>(IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
         {
             if (first is ICollection<TSource> firstCol && second is ICollection<TSource> secondCol && firstCol.Count != secondCol.Count)
             {
-                return false;
+                return Task.FromResult(false);
             }
 
-            var e1 = first.GetAsyncEnumerator(cancellationToken);
+            return Core();
 
-            try
+            async Task<bool> Core()
             {
-                var e2 = second.GetAsyncEnumerator(cancellationToken);
+                var e1 = first.GetAsyncEnumerator(cancellationToken);
 
                 try
                 {
-                    while (await e1.MoveNextAsync().ConfigureAwait(false))
+                    var e2 = second.GetAsyncEnumerator(cancellationToken);
+
+                    try
                     {
-                        if (!(await e2.MoveNextAsync().ConfigureAwait(false) && comparer.Equals(e1.Current, e2.Current)))
+                        while (await e1.MoveNextAsync().ConfigureAwait(false))
                         {
-                            return false;
+                            if (!(await e2.MoveNextAsync().ConfigureAwait(false) && comparer.Equals(e1.Current, e2.Current)))
+                            {
+                                return false;
+                            }
                         }
-                    }
 
-                    return !await e2.MoveNextAsync().ConfigureAwait(false);
+                        return !await e2.MoveNextAsync().ConfigureAwait(false);
+                    }
+                    finally
+                    {
+                        await e2.DisposeAsync().ConfigureAwait(false);
+                    }
                 }
                 finally
                 {
-                    await e2.DisposeAsync().ConfigureAwait(false);
+                    await e1.DisposeAsync().ConfigureAwait(false);
                 }
             }
-            finally
-            {
-                await e1.DisposeAsync().ConfigureAwait(false);
-            }
         }
     }
 }