浏览代码

Optimize OfType and Cast.

Bart De Smet 7 年之前
父节点
当前提交
679cfe7942

+ 56 - 1
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/Cast.cs

@@ -3,6 +3,9 @@
 // See the LICENSE file in the project root for more information. 
 
 using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
@@ -19,7 +22,59 @@ namespace System.Linq
                 return typedSource;
             }
 
-            return source.Select(x => (TResult)x);
+            return new CastAsyncIterator<TResult>(source);
+        }
+
+        internal sealed class CastAsyncIterator<TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<object> _source;
+            private IAsyncEnumerator<object> _enumerator;
+
+            public CastAsyncIterator(IAsyncEnumerable<object> source)
+            {
+                Debug.Assert(source != null);
+
+                _source = source;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new CastAsyncIterator<TResult>(_source);
+            }
+
+            public override async ValueTask DisposeAsync()
+            {
+                if (_enumerator != null)
+                {
+                    await _enumerator.DisposeAsync().ConfigureAwait(false);
+                    _enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        _enumerator = _source.GetAsyncEnumerator(cancellationToken);
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await _enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            current = (TResult)_enumerator.Current;
+                            return true;
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
         }
     }
 }

+ 61 - 2
Ix.NET/Source/System.Linq.Async/System/Linq/Operators/OfType.cs

@@ -3,17 +3,76 @@
 // See the LICENSE file in the project root for more information. 
 
 using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
     public static partial class AsyncEnumerable
     {
-        public static IAsyncEnumerable<TType> OfType<TType>(this IAsyncEnumerable<object> source)
+        public static IAsyncEnumerable<TResult> OfType<TResult>(this IAsyncEnumerable<object> source)
         {
             if (source == null)
                 throw Error.ArgumentNull(nameof(source));
 
-            return source.Where(x => x is TType).Cast<TType>();
+            return new OfTypeAsyncIterator<TResult>(source);
+        }
+
+        internal sealed class OfTypeAsyncIterator<TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<object> _source;
+            private IAsyncEnumerator<object> _enumerator;
+
+            public OfTypeAsyncIterator(IAsyncEnumerable<object> source)
+            {
+                Debug.Assert(source != null);
+
+                _source = source;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new OfTypeAsyncIterator<TResult>(_source);
+            }
+
+            public override async ValueTask DisposeAsync()
+            {
+                if (_enumerator != null)
+                {
+                    await _enumerator.DisposeAsync().ConfigureAwait(false);
+                    _enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        _enumerator = _source.GetAsyncEnumerator(cancellationToken);
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await _enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            var item = _enumerator.Current;
+                            if (item is TResult res)
+                            {
+                                current = res;
+                                return true;
+                            }
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
         }
     }
 }