Przeglądaj źródła

Improving GetCountAsync to avoid async machinery.

Bart De Smet 7 lat temu
rodzic
commit
01f79649b6

+ 31 - 26
Ix.NET/Source/System.Linq.Async/System/Linq/AsyncEnumerablePartition.cs

@@ -62,40 +62,45 @@ namespace System.Linq
             await base.DisposeAsync().ConfigureAwait(false);
         }
 
-        public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+        public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
         {
             if (onlyIfCheap)
             {
-                return -1;
+                return TaskExt.MinusOne;
             }
 
-            if (!HasLimit)
+            return Core();
+
+            async Task<int> Core()
             {
-                // If HasLimit is false, we contain everything past _minIndexInclusive.
-                // Therefore, we have to iterate the whole enumerable.
-                return Math.Max(await _source.Count(cancellationToken).ConfigureAwait(false) - _minIndexInclusive, 0);
-            }
+                if (!HasLimit)
+                {
+                    // If HasLimit is false, we contain everything past _minIndexInclusive.
+                    // Therefore, we have to iterate the whole enumerable.
+                    return Math.Max(await _source.Count(cancellationToken).ConfigureAwait(false) - _minIndexInclusive, 0);
+                }
 
-            var en = _source.GetAsyncEnumerator(cancellationToken);
+                var en = _source.GetAsyncEnumerator(cancellationToken);
 
-            try
-            {
-                // We only want to iterate up to _maxIndexInclusive + 1.
-                // Past that, we know the enumerable will be able to fit this partition,
-                // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.
-
-                // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
-                // so + 1 may result in signed integer overflow. We need to handle this.
-                // At the same time, however, we are guaranteed that our max count can fit
-                // in an int because if that is true, then _minIndexInclusive must > 0.
-
-                var count = await SkipAndCountAsync((uint)_maxIndexInclusive + 1, en, cancellationToken).ConfigureAwait(false);
-                Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
-                return Math.Max((int)count - _minIndexInclusive, 0);
-            }
-            finally
-            {
-                await en.DisposeAsync().ConfigureAwait(false);
+                try
+                {
+                    // We only want to iterate up to _maxIndexInclusive + 1.
+                    // Past that, we know the enumerable will be able to fit this partition,
+                    // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.
+
+                    // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
+                    // so + 1 may result in signed integer overflow. We need to handle this.
+                    // At the same time, however, we are guaranteed that our max count can fit
+                    // in an int because if that is true, then _minIndexInclusive must > 0.
+
+                    var count = await SkipAndCountAsync((uint)_maxIndexInclusive + 1, en, cancellationToken).ConfigureAwait(false);
+                    Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
+                    return Math.Max((int)count - _minIndexInclusive, 0);
+                }
+                finally
+                {
+                    await en.DisposeAsync().ConfigureAwait(false);
+                }
             }
         }
 

+ 18 - 13
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Concat.cs

@@ -100,29 +100,34 @@ namespace System.Linq
                 return list;
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var count = 0;
-                for (var i = 0; ; i++)
+                return Core();
+
+                async Task<int> Core()
                 {
-                    var source = GetAsyncEnumerable(i);
-                    if (source == null)
+                    var count = 0;
+                    for (var i = 0; ; i++)
                     {
-                        break;
-                    }
+                        var source = GetAsyncEnumerable(i);
+                        if (source == null)
+                        {
+                            break;
+                        }
 
-                    checked
-                    {
-                        count += await source.Count(cancellationToken).ConfigureAwait(false);
+                        checked
+                        {
+                            count += await source.Count(cancellationToken).ConfigureAwait(false);
+                        }
                     }
-                }
 
-                return count;
+                    return count;
+                }
             }
 
             public override async ValueTask DisposeAsync()

+ 54 - 24
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/GroupBy.cs

@@ -295,16 +295,21 @@ namespace System.Linq
                 return l.ToList(_resultSelector);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var l = await Internal.Lookup<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
+                return Core();
+
+                async Task<int> Core()
+                {
+                    var l = await Internal.Lookup<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
 
-                return l.Count;
+                    return l.Count;
+                }
             }
         }
 
@@ -384,16 +389,21 @@ namespace System.Linq
                 return await l.ToList(_resultSelector).ConfigureAwait(false);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var l = await Internal.LookupWithTask<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
+                return Core();
 
-                return l.Count;
+                async Task<int> Core()
+                {
+                    var l = await Internal.LookupWithTask<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
+
+                    return l.Count;
+                }
             }
         }
 
@@ -473,16 +483,21 @@ namespace System.Linq
                 return await l.ToListAsync(cancellationToken).ConfigureAwait(false);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var l = await Internal.Lookup<TKey, TElement>.CreateAsync(_source, _keySelector, _elementSelector, _comparer, cancellationToken).ConfigureAwait(false);
+                return Core();
 
-                return l.Count;
+                async Task<int> Core()
+                {
+                    var l = await Internal.Lookup<TKey, TElement>.CreateAsync(_source, _keySelector, _elementSelector, _comparer, cancellationToken).ConfigureAwait(false);
+
+                    return l.Count;
+                }
             }
         }
 
@@ -562,16 +577,21 @@ namespace System.Linq
                 return await l.ToListAsync(cancellationToken).ConfigureAwait(false);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var l = await Internal.LookupWithTask<TKey, TElement>.CreateAsync(_source, _keySelector, _elementSelector, _comparer, cancellationToken).ConfigureAwait(false);
+                return Core();
 
-                return l.Count;
+                async Task<int> Core()
+                {
+                    var l = await Internal.LookupWithTask<TKey, TElement>.CreateAsync(_source, _keySelector, _elementSelector, _comparer, cancellationToken).ConfigureAwait(false);
+
+                    return l.Count;
+                }
             }
         }
 
@@ -647,16 +667,21 @@ namespace System.Linq
                 return await l.ToListAsync(cancellationToken).ConfigureAwait(false);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var l = await Internal.Lookup<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
+                return Core();
+
+                async Task<int> Core()
+                {
+                    var l = await Internal.Lookup<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
 
-                return l.Count;
+                    return l.Count;
+                }
             }
         }
 
@@ -732,16 +757,21 @@ namespace System.Linq
                 return await l.ToListAsync(cancellationToken).ConfigureAwait(false);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var l = await Internal.LookupWithTask<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
+                return Core();
 
-                return l.Count;
+                async Task<int> Core()
+                {
+                    var l = await Internal.LookupWithTask<TKey, TSource>.CreateAsync(_source, _keySelector, _comparer, cancellationToken).ConfigureAwait(false);
+
+                    return l.Count;
+                }
             }
         }
     }

+ 14 - 9
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Select.cs

@@ -471,26 +471,31 @@ namespace System.Linq
                 await base.DisposeAsync().ConfigureAwait(false);
             }
 
-            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            public Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
                 if (onlyIfCheap)
                 {
-                    return -1;
+                    return TaskExt.MinusOne;
                 }
 
-                var count = 0;
+                return Core();
 
-                foreach (var item in _source)
+                async Task<int> Core()
                 {
-                    await _selector(item).ConfigureAwait(false);
+                    var count = 0;
 
-                    checked
+                    foreach (var item in _source)
                     {
-                        count++;
+                        await _selector(item).ConfigureAwait(false);
+
+                        checked
+                        {
+                            count++;
+                        }
                     }
-                }
 
-                return count;
+                    return count;
+                }
             }
 
             public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, Task<TResult1>> selector)

+ 1 - 0
Ix.NET/Source/System.Linq.Async/System/Threading/Tasks/TaskExt.cs

@@ -9,5 +9,6 @@ namespace System.Threading.Tasks
         public static readonly ValueTask<bool> True = new ValueTask<bool>(true);
         public static readonly ValueTask<bool> False = new ValueTask<bool>(false);
         public static readonly ValueTask CompletedTask = new ValueTask(Task.FromResult(true));
+        public static readonly Task<int> MinusOne = Task.FromResult(-1);
     }
 }