فهرست منبع

Where/Select with Index

Oren Novotny 9 سال پیش
والد
کامیت
25f995cb1d
2فایلهای تغییر یافته به همراه152 افزوده شده و 84 حذف شده
  1. 75 41
      Ix.NET/Source/System.Interactive.Async/Select.cs
  2. 77 43
      Ix.NET/Source/System.Interactive.Async/Where.cs

+ 75 - 41
Ix.NET/Source/System.Interactive.Async/Select.cs

@@ -30,7 +30,7 @@ namespace System.Linq
             var ilist = source as IList<TSource>;
             if (ilist != null)
             {
-               return new SelectIListIterator<TSource, TResult>(ilist, selector);
+                return new SelectIListIterator<TSource, TResult>(ilist, selector);
             }
 
             return new SelectEnumerableAsyncIterator<TSource, TResult>(source, selector);
@@ -43,46 +43,19 @@ namespace System.Linq
             if (selector == null)
                 throw new ArgumentNullException(nameof(selector));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-                    var current = default(TResult);
-                    var index = 0;
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
-
-                    return CreateEnumerator(
-                        async ct =>
-                        {
-                            if (await e.MoveNext(cts.Token)
-                                       .ConfigureAwait(false))
-                            {
-                                current = selector(e.Current, checked(index++));
-                                return true;
-                            }
-                            return false;
-                        },
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new SelectEnumerableWithIndexAsyncIterator<TSource, TResult>(source, selector);
         }
 
         private static Func<TSource, TResult> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, TMiddle> selector1, Func<TMiddle, TResult> selector2)
 
         {
-
             return x => selector2(selector1(x));
-
         }
 
         internal sealed class SelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
         {
-            private readonly IAsyncEnumerable<TSource> source;
             private readonly Func<TSource, TResult> selector;
+            private readonly IAsyncEnumerable<TSource> source;
             private IAsyncEnumerator<TSource> enumerator;
 
             public SelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
@@ -106,10 +79,15 @@ namespace System.Linq
                     enumerator.Dispose();
                     enumerator = null;
                 }
-   
+
                 base.Dispose();
             }
 
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+            {
+                return new SelectEnumerableAsyncIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+            }
+
             protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
             {
                 switch (state)
@@ -125,7 +103,6 @@ namespace System.Linq
                         {
                             current = selector(enumerator.Current);
                             return true;
-
                         }
 
                         Dispose();
@@ -134,17 +111,74 @@ namespace System.Linq
 
                 return false;
             }
+        }
 
-            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+        internal sealed class SelectEnumerableWithIndexAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly Func<TSource, int, TResult> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private IAsyncEnumerator<TSource> enumerator;
+            private int index;
+
+            public SelectEnumerableWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, TResult> selector)
             {
-                return new SelectEnumerableAsyncIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectEnumerableWithIndexAsyncIterator<TSource, TResult>(source, selector);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        index = -1;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            checked
+                            {
+                                index++;
+                            }
+                            current = selector(enumerator.Current, index);
+                            return true;
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
             }
         }
 
         internal sealed class SelectIListIterator<TSource, TResult> : AsyncIterator<TResult>
         {
-            private readonly IList<TSource> source;
             private readonly Func<TSource, TResult> selector;
+            private readonly IList<TSource> source;
             private IEnumerator<TSource> enumerator;
 
             public SelectIListIterator(IList<TSource> source, Func<TSource, TResult> selector)
@@ -172,6 +206,11 @@ namespace System.Linq
                 base.Dispose();
             }
 
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+            {
+                return new SelectIListIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+            }
+
             protected override Task<bool> MoveNextCore(CancellationToken cancellationToken)
             {
                 switch (state)
@@ -182,7 +221,7 @@ namespace System.Linq
                         goto case AsyncIteratorState.Iterating;
 
                     case AsyncIteratorState.Iterating:
-                        if ( enumerator.MoveNext())
+                        if (enumerator.MoveNext())
                         {
                             current = selector(enumerator.Current);
                             return Task.FromResult(true);
@@ -194,11 +233,6 @@ namespace System.Linq
 
                 return Task.FromResult(false);
             }
-
-            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
-            {
-                return new SelectIListIterator<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
-            }
         }
     }
 }

+ 77 - 43
Ix.NET/Source/System.Interactive.Async/Where.cs

@@ -37,50 +37,19 @@ namespace System.Linq
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = source.GetEnumerator();
-                    var index = 0;
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                if (predicate(e.Current, checked(index++)))
-                                    return true;
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-                            return false;
-                        };
-
-                    return CreateEnumerator(
-                        ct => f(cts.Token),
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new WhereEnumerableWithIndexAsyncIterator<TSource>(source, predicate);
         }
 
         private static Func<TSource, bool> CombinePredicates<TSource>(Func<TSource, bool> predicate1, Func<TSource, bool> predicate2)
 
         {
-
             return x => predicate1(x) && predicate2(x);
-
         }
 
         internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
         {
-            private readonly IAsyncEnumerable<TSource> source;
             private readonly Func<TSource, bool> predicate;
+            private readonly IAsyncEnumerable<TSource> source;
             private IAsyncEnumerator<TSource> enumerator;
 
             public WhereEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
@@ -107,6 +76,16 @@ namespace System.Linq
                 base.Dispose();
             }
 
+            public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
+            {
+                return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(source, predicate, selector);
+            }
+
+            public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
+            {
+                return new WhereEnumerableAsyncIterator<TSource>(source, CombinePredicates(this.predicate, predicate));
+            }
+
             protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
             {
                 switch (state)
@@ -134,23 +113,78 @@ namespace System.Linq
 
                 return false;
             }
+        }
 
-            public override IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
+        internal sealed class WhereEnumerableWithIndexAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, int, bool> predicate;
+            private readonly IAsyncEnumerable<TSource> source;
+            private IAsyncEnumerator<TSource> enumerator;
+            private int index;
+
+            public WhereEnumerableWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
             {
-                return new WhereEnumerableAsyncIterator<TSource>(source, CombinePredicates(this.predicate, predicate));
+                Debug.Assert(source != null);
+                Debug.Assert(predicate != null);
+
+                this.source = source;
+                this.predicate = predicate;
             }
 
-            public override IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
+            public override AsyncIterator<TSource> Clone()
             {
-                return new WhereSelectEnumerableAsyncIterator<TSource, TResult>(source, predicate, selector);
+                return new WhereEnumerableWithIndexAsyncIterator<TSource>(source, predicate);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        index = -1;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await enumerator.MoveNext(cancellationToken)
+                                               .ConfigureAwait(false))
+                        {
+                            checked
+                            {
+                                index++;
+                            }
+                            var item = enumerator.Current;
+                            if (predicate(item, index))
+                            {
+                                current = item;
+                                return true;
+                            }
+                        }
+
+                        Dispose();
+                        break;
+                }
+
+                return false;
             }
         }
 
         internal sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
         {
-            private readonly IAsyncEnumerable<TSource> source;
             private readonly Func<TSource, bool> predicate;
             private readonly Func<TSource, TResult> selector;
+            private readonly IAsyncEnumerable<TSource> source;
             private IAsyncEnumerator<TSource> enumerator;
 
             public WhereSelectEnumerableAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, Func<TSource, TResult> selector)
@@ -180,6 +214,11 @@ namespace System.Linq
                 base.Dispose();
             }
 
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
+            {
+                return new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(source, predicate, CombineSelectors(this.selector, selector));
+            }
+
             protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
             {
                 switch (state)
@@ -207,11 +246,6 @@ namespace System.Linq
 
                 return false;
             }
-
-            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, TResult1> selector)
-            {
-                return new WhereSelectEnumerableAsyncIterator<TSource, TResult1>(source, predicate, CombineSelectors(this.selector, selector));
-            }
         }
     }
 }