浏览代码

优化一些集合扩展函数的性能

懒得勤快 2 年之前
父节点
当前提交
0b0bd17f67
共有 1 个文件被更改,包括 55 次插入20 次删除
  1. 55 20
      Masuit.Tools.Abstractions/Extensions/BaseType/IEnumerableExtensions.cs

+ 55 - 20
Masuit.Tools.Abstractions/Extensions/BaseType/IEnumerableExtensions.cs

@@ -9,6 +9,9 @@ using System.Threading.Tasks;
 
 namespace Masuit.Tools;
 
+/// <summary>
+/// 
+/// </summary>
 public static class IEnumerableExtensions
 {
     /// <summary>
@@ -47,6 +50,7 @@ public static class IEnumerableExtensions
     /// <param name="first"></param>
     /// <param name="second"></param>
     /// <param name="keySelector"></param>
+    /// <param name="comparer"></param>
     /// <returns></returns>
     public static IEnumerable<TSource> IntersectBy<TSource, TKey>(this IEnumerable<TSource> first, IEnumerable<TSource> second, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
     {
@@ -85,7 +89,9 @@ public static class IEnumerableExtensions
     /// 多个集合取交集元素
     /// </summary>
     /// <typeparam name="TSource"></typeparam>
+    /// <typeparam name="TKey"></typeparam>
     /// <param name="source"></param>
+    /// <param name="keySelector"></param>
     /// <returns></returns>
     public static IEnumerable<TSource> IntersectAll<TSource, TKey>(this IEnumerable<IEnumerable<TSource>> source, Func<TSource, TKey> keySelector)
     {
@@ -96,7 +102,10 @@ public static class IEnumerableExtensions
     /// 多个集合取交集元素
     /// </summary>
     /// <typeparam name="TSource"></typeparam>
+    /// <typeparam name="TKey"></typeparam>
     /// <param name="source"></param>
+    /// <param name="keySelector"></param>
+    /// <param name="comparer"></param>
     /// <returns></returns>
     public static IEnumerable<TSource> IntersectAll<TSource, TKey>(this IEnumerable<IEnumerable<TSource>> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
     {
@@ -108,6 +117,7 @@ public static class IEnumerableExtensions
     /// </summary>
     /// <typeparam name="T"></typeparam>
     /// <param name="source"></param>
+    /// <param name="comparer"></param>
     /// <returns></returns>
     public static IEnumerable<T> IntersectAll<T>(this IEnumerable<IEnumerable<T>> source, IEqualityComparer<T> comparer)
     {
@@ -167,8 +177,9 @@ public static class IEnumerableExtensions
     /// <param name="first"></param>
     /// <param name="second"></param>
     /// <param name="keySelector"></param>
+    /// <param name="comparer"></param>
     /// <returns></returns>
-    public static IEnumerable<TSource> IntersectBy<TSource, TKey>(this IEnumerable<TSource> first, IEnumerable<TKey> second, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer)
+    public static IEnumerable<TSource> IntersectBy<TSource, TKey>(this IEnumerable<TSource> first, IEnumerable<TKey> second, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
     {
         if (first == null)
             throw new ArgumentNullException(nameof(first));
@@ -214,7 +225,7 @@ public static class IEnumerableExtensions
     /// <param name="comparer"></param>
     /// <returns></returns>
     /// <exception cref="ArgumentNullException"></exception>
-    public static IEnumerable<TSource> ExceptBy<TSource, TKey>(this IEnumerable<TSource> first, IEnumerable<TKey> second, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer)
+    public static IEnumerable<TSource> ExceptBy<TSource, TKey>(this IEnumerable<TSource> first, IEnumerable<TKey> second, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
     {
         if (first == null)
             throw new ArgumentNullException(nameof(first));
@@ -505,10 +516,17 @@ public static class IEnumerableExtensions
     /// <typeparam name="T"></typeparam>
     /// <param name="source"></param>
     /// <param name="action"></param>
+    /// <param name="cancellationToken"></param>
     /// <returns></returns>
     public static Task ForeachAsync<T>(this IEnumerable<T> source, Func<T, Task> action, CancellationToken cancellationToken = default)
     {
-        return ForeachAsync(source, action, source.Count(), cancellationToken);
+        if (source is ICollection<T> collection)
+        {
+            return ForeachAsync(collection, action, collection.Count, cancellationToken);
+        }
+
+        var list = source.ToList();
+        return ForeachAsync(list, action, list.Count, cancellationToken);
     }
 
     /// <summary>
@@ -556,8 +574,10 @@ public static class IEnumerableExtensions
             tasks.Add(task);
             if (tasks.Count >= maxParallelCount)
             {
-                results.AddRange(await Task.WhenAll(tasks));
-                tasks.Clear();
+                await Task.WhenAny(tasks);
+                var completedTasks = tasks.Where(t => t.IsCompleted).ToArray();
+                results.AddRange(completedTasks.Select(t => t.Result));
+                tasks.RemoveWhere(t => completedTasks.Contains(t));
             }
         }
 
@@ -581,12 +601,15 @@ public static class IEnumerableExtensions
         int index = 0;
         foreach (var item in source)
         {
-            var task = selector(item, index++);
+            var task = selector(item, index);
             tasks.Add(task);
+            Interlocked.Add(ref index, 1);
             if (tasks.Count >= maxParallelCount)
             {
-                results.AddRange(await Task.WhenAll(tasks));
-                tasks.Clear();
+                await Task.WhenAny(tasks);
+                var completedTasks = tasks.Where(t => t.IsCompleted).ToArray();
+                results.AddRange(completedTasks.Select(t => t.Result));
+                tasks.RemoveWhere(t => completedTasks.Contains(t));
             }
         }
 
@@ -629,8 +652,8 @@ public static class IEnumerableExtensions
             Interlocked.Add(ref index, 1);
             if (list.Count >= maxParallelCount)
             {
-                await Task.WhenAll(list);
-                list.Clear();
+                await Task.WhenAny(list);
+                list.RemoveAll(t => t.IsCompleted);
             }
         }
 
@@ -647,7 +670,13 @@ public static class IEnumerableExtensions
     /// <returns></returns>
     public static Task ForAsync<T>(this IEnumerable<T> source, Func<T, int, Task> selector, CancellationToken cancellationToken = default)
     {
-        return ForAsync(source, selector, source.Count(), cancellationToken);
+        if (source is ICollection<T> collection)
+        {
+            return ForAsync(collection, selector, collection.Count, cancellationToken);
+        }
+
+        var list = source.ToList();
+        return ForAsync(list, selector, list.Count, cancellationToken);
     }
 
     /// <summary>
@@ -806,6 +835,7 @@ public static class IEnumerableExtensions
     /// 标准差
     /// </summary>
     /// <typeparam name="T"></typeparam>
+    /// <typeparam name="TResult"></typeparam>
     /// <param name="source"></param>
     /// <param name="selector"></param>
     /// <returns></returns>
@@ -833,11 +863,12 @@ public static class IEnumerableExtensions
     public static double StandardDeviation(this IEnumerable<double> source)
     {
         double result = 0;
-        int count = source.Count();
+        var list = source as ICollection<double> ?? source.ToList();
+        int count = list.Count();
         if (count > 1)
         {
-            double avg = source.Average();
-            double sum = source.Sum(d => (d - avg) * (d - avg));
+            var avg = list.Average();
+            var sum = list.Sum(d => (d - avg) * (d - avg));
             result = Math.Sqrt(sum / count);
         }
 
@@ -959,9 +990,11 @@ public static class IEnumerableExtensions
     {
         first ??= new List<T1>();
         second ??= new List<T2>();
-        var add = first.ExceptBy(second, condition).ToList();
-        var remove = second.ExceptBy(first, (s, f) => condition(f, s)).ToList();
-        var update = first.IntersectBy(second, condition).ToList();
+        var firstSource = first as ICollection<T1> ?? first.ToList();
+        var secondSource = second as ICollection<T2> ?? second.ToList();
+        var add = firstSource.ExceptBy(secondSource, condition).ToList();
+        var remove = secondSource.ExceptBy(firstSource, (s, f) => condition(f, s)).ToList();
+        var update = firstSource.IntersectBy(secondSource, condition).ToList();
         return (add, remove, update);
     }
 
@@ -978,9 +1011,11 @@ public static class IEnumerableExtensions
     {
         first ??= new List<T1>();
         second ??= new List<T2>();
-        var add = first.ExceptBy(second, condition).ToList();
-        var remove = second.ExceptBy(first, (s, f) => condition(f, s)).ToList();
-        var updates = first.IntersectBy(second, condition).Select(t1 => (t1, second.FirstOrDefault(t2 => condition(t1, t2)))).ToList();
+        var firstSource = first as ICollection<T1> ?? first.ToList();
+        var secondSource = second as ICollection<T2> ?? second.ToList();
+        var add = firstSource.ExceptBy(secondSource, condition).ToList();
+        var remove = secondSource.ExceptBy(firstSource, (s, f) => condition(f, s)).ToList();
+        var updates = firstSource.IntersectBy(secondSource, condition).Select(t1 => (t1, secondSource.FirstOrDefault(t2 => condition(t1, t2)))).ToList();
         return (add, remove, updates);
     }