Răsfoiți Sursa

Optimize Contains.

Bart De Smet 7 ani în urmă
părinte
comite
f7bd7fed23

+ 52 - 9
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Contains.cs

@@ -15,25 +15,25 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            return ContainsCore(source, value, EqualityComparer<TSource>.Default, CancellationToken.None);
+            return ContainsCore(source, value, CancellationToken.None);
         }
 
-        public static Task<bool> Contains<TSource>(this IAsyncEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer)
+        public static Task<bool> Contains<TSource>(this IAsyncEnumerable<TSource> source, TSource value, CancellationToken cancellationToken)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
-            if (comparer == null)
-                throw new ArgumentNullException(nameof(comparer));
 
-            return ContainsCore(source, value, comparer, CancellationToken.None);
+            return ContainsCore(source, value, cancellationToken);
         }
 
-        public static Task<bool> Contains<TSource>(this IAsyncEnumerable<TSource> source, TSource value, CancellationToken cancellationToken)
+        public static Task<bool> Contains<TSource>(this IAsyncEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
 
-            return ContainsCore(source, value, EqualityComparer<TSource>.Default, cancellationToken);
+            return ContainsCore(source, value, comparer, CancellationToken.None);
         }
 
         public static Task<bool> Contains<TSource>(this IAsyncEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
@@ -46,9 +46,52 @@ namespace System.Linq
             return ContainsCore(source, value, comparer, cancellationToken);
         }
 
-        private static Task<bool> ContainsCore<TSource>(IAsyncEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
+        private static Task<bool> ContainsCore<TSource>(IAsyncEnumerable<TSource> source, TSource value, CancellationToken cancellationToken)
+        {
+            if (source is ICollection<TSource> collection)
+            {
+                return Task.FromResult(collection.Contains(value));
+            }
+
+            return ContainsCore(source, value, comparer: null, cancellationToken);
+        }
+
+        private static async Task<bool> ContainsCore<TSource>(IAsyncEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
         {
-            return source.Any(x => comparer.Equals(x, value), cancellationToken);
+            var e = source.GetAsyncEnumerator(cancellationToken);
+
+            try
+            {
+                //
+                // See https://github.com/dotnet/corefx/pull/25097 for the optimization here.
+                //
+                if (comparer == null)
+                {
+                    while (await e.MoveNextAsync().ConfigureAwait(false))
+                    {
+                        if (EqualityComparer<TSource>.Default.Equals(e.Current, value))
+                        {
+                            return true;
+                        }
+                    }
+                }
+                else
+                {
+                    while (await e.MoveNextAsync().ConfigureAwait(false))
+                    {
+                        if (comparer.Equals(e.Current, value))
+                        {
+                            return true;
+                        }
+                    }
+                }
+            }
+            finally
+            {
+                await e.DisposeAsync().ConfigureAwait(false);
+            }
+
+            return false;
         }
     }
 }