瀏覽代碼

Add where/select optimizations

Oren Novotny 9 年之前
父節點
當前提交
d962aebc40

+ 106 - 0
Ix.NET/Source/System.Interactive.Async/AsyncIterator.cs

@@ -0,0 +1,106 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace System.Linq
+{
+    public static partial class AsyncEnumerable
+    {
+        internal abstract class AsyncIterator<TSource> : IAsyncEnumerable<TSource>, IAsyncEnumerator<TSource>
+        {
+            public enum State
+            {
+                New = 0,
+                Allocated = 1,
+                Iterating = 2,
+                Disposed = -1,
+            }
+
+            private readonly int threadId;
+            internal State state = State.New;
+            internal TSource current;
+            private CancellationTokenSource cancellationTokenSource;
+
+            protected AsyncIterator()
+            {
+                threadId = Environment.CurrentManagedThreadId;
+            }
+
+            public abstract AsyncIterator<TSource> Clone();
+
+            public IAsyncEnumerator<TSource> GetEnumerator()
+            {
+                var enumerator = state == State.New && threadId == Environment.CurrentManagedThreadId ? this : Clone();
+
+                enumerator.state = State.Allocated;
+                enumerator.cancellationTokenSource = new CancellationTokenSource();
+                return enumerator;
+            }
+
+            
+            public virtual void Dispose()
+            {
+                if (!cancellationTokenSource.IsCancellationRequested)
+                {
+                    cancellationTokenSource.Cancel();
+                }
+                cancellationTokenSource.Dispose();
+                current = default(TSource);
+                state = State.Disposed;
+            }
+
+            private void Cancel()
+            {
+                if (!cancellationTokenSource.IsCancellationRequested)
+                {
+                    cancellationTokenSource.Cancel();
+                }
+                Dispose();
+                Debug.WriteLine("Canceled");
+            }
+
+            public TSource Current => current;
+
+            public async Task<bool> MoveNext(CancellationToken cancellationToken)
+            {
+              //  using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, cancellationTokenSource.Token))
+                using (cancellationToken.Register(Cancel))
+
+                {
+                    try
+                    {
+                        var result = await MoveNextCore(cancellationTokenSource.Token).ConfigureAwait(false);
+
+                     //   cts.Dispose();
+                        //if (cts.IsCancellationRequested)
+                        //{
+                        //    Dispose();
+                        //}
+
+                        return result;
+                    }
+                    catch
+                    {
+                        Dispose();
+                        throw;
+                    }
+                }
+            }
+
+            public abstract Task<bool> MoveNextCore(CancellationToken cancellationToken);
+
+            public virtual IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
+            {
+                return new SelectEnumerableAsyncIterator<TSource, TResult>(this, selector);
+            }
+
+            public virtual IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
+            {
+                return new WhereEnumerableAsyncIterator<TSource>(this, predicate);
+            }
+        }
+    }
+}

+ 80 - 23
Ix.NET/Source/System.Interactive.Async/Select.cs

@@ -4,7 +4,9 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Linq
@@ -18,31 +20,16 @@ namespace System.Linq
             if (selector == null)
                 throw new ArgumentNullException(nameof(selector));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-                    var current = default(TResult);
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+            var iterator = source as AsyncIterator<TSource>;
+            if (iterator != null)
+            {
+                return iterator.Select(selector);
+            }
 
-                    return CreateEnumerator(
-                        async ct =>
-                        {
-                            if (await e.MoveNext(cts.Token)
-                                       .ConfigureAwait(false))
-                            {
-                                current = selector(e.Current);
-                                return true;
-                            }
-                            return false;
-                        },
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            // TODO: Can we add optimizations for IList or anything else here?
+
+            return new SelectEnumerableAsyncIterator<TSource, TResult>(source, selector);
         }
 
         public static IAsyncEnumerable<TResult> Select<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, TResult> selector)
@@ -79,5 +66,75 @@ namespace System.Linq
                     );
                 });
         }
+
+        private static Func<TSource, TResult> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, TMiddle> selector1, Func<TMiddle, TResult> selector2)
+
+        {
+
+            return x => selector2(selector1(x));
+
+        }
+
+        internal sealed class SelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly Func<TSource, TResult> selector;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public SelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectEnumerableAsyncIterator<TSource, TResult>(source, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+   
+                base.Dispose();
+            }
+
+            public override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case State.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = State.Iterating;
+                        goto case State.Iterating;
+
+                    case State.Iterating:
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            current = selector(enumerator.Current);
+                            return true;
+
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+            {
+                return new SelectEnumerableAsyncIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+            }
+        }
     }
 }

+ 153 - 28
Ix.NET/Source/System.Interactive.Async/Where.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -19,35 +20,14 @@ namespace System.Linq
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
+            var iterator = source as AsyncIterator<TSource>;
+            if (iterator != null)
+            {
+                return iterator.Where(predicate);
+            }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                if (predicate(e.Current))
-                                    return true;
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-                            return false;
-                        };
-
-                    return CreateEnumerator(
-                        ct => f(cts.Token),
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            // TODO: Can we add array/list optimizations here, does it make sense?
+            return new WhereEnumerableAsyncIterator<TSource>(source, predicate);
         }
 
         public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
@@ -88,5 +68,150 @@ namespace System.Linq
                     );
                 });
         }
+
+        private static Func<TSource, bool> CombinePredicates<TSource>(Func<TSource, bool> predicate1, Func<TSource, bool> predicate2)
+
+        {
+
+            return x => predicate1(x) && predicate2(x);
+
+        }
+
+        internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly Func<TSource, bool> predicate;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public WhereEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(predicate != null);
+
+                this.source = source;
+                this.predicate = predicate;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new WhereEnumerableAsyncIterator<TSource>(source, predicate);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+                base.Dispose();
+            }
+
+            public override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case State.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = State.Iterating;
+                        goto case State.Iterating;
+
+                    case State.Iterating:
+                        while (await enumerator.MoveNext(cancellationToken)
+                                               .ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            if (predicate(item))
+                            {
+                                current = item;
+                                return true;
+                            }
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
+            {
+                return new WhereEnumerableAsyncIterator<TSource>(source, CombinePredicates(this.predicate, predicate));
+            }
+
+            public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
+            {
+                return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(source, predicate, selector);
+            }
+        }
+
+        internal sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly Func<TSource, bool> predicate;
+            private readonly Func<TSource, TResult> selector;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, Func<TSource, TResult> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(predicate != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.predicate = predicate;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(source, predicate, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            public override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case State.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = State.Iterating;
+                        goto case State.Iterating;
+
+                    case State.Iterating:
+                        while (await enumerator.MoveNext(cancellationToken)
+                                               .ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            if (predicate(item))
+                            {
+                                current = selector(item);
+                                return true;
+                            }
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
+            }
+
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+            {
+                return new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(source, predicate, CombineSelectors(this.selector, selector));
+            }
+        }
     }
 }