Oren Novotny 9 年之前
父節點
當前提交
2aca3c8719
共有 2 個文件被更改,包括 307 次插入138 次删除
  1. 307 136
      Ix.NET/Source/System.Interactive.Async/Skip.cs
  2. 0 2
      Ix.NET/Source/Tests/AsyncTests.Single.cs

+ 307 - 136
Ix.NET/Source/System.Interactive.Async/Skip.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -16,90 +17,40 @@ namespace System.Linq
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
-            if (count < 0)
-                throw new ArgumentOutOfRangeException(nameof(count));
 
-            return CreateEnumerable(
-                () =>
+            if (count <= 0)
+            {
+                // Return source if not actually skipping, but only if it's a type from here, to avoid
+                // issues if collections are used as keys or otherwise must not be aliased.
+                if (source is AsyncIterator<TSource>)
                 {
-                    var e = source.GetEnumerator();
-                    var n = count;
+                    return source;
+                }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+                count = 0;
+            }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            var moveNext = await e.MoveNext(ct)
-                                                  .ConfigureAwait(false);
-                            if (n == 0)
-                            {
-                                return moveNext;
-                            }
-                            --n;
-                            if (!moveNext)
-                            {
-                                return false;
-                            }
-                            return await f(ct)
-                                       .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        ct => f(cts.Token),
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new SkipAsyncIterator<TSource>(source, count);
         }
 
         public static IAsyncEnumerable<TSource> SkipLast<TSource>(this IAsyncEnumerable<TSource> source, int count)
         {
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
-            if (count < 0)
-                throw new ArgumentOutOfRangeException(nameof(count));
 
-            return CreateEnumerable(
-                () =>
+            if (count <= 0)
+            {
+                // Return source if not actually skipping, but only if it's a type from here, to avoid
+                // issues if collections are used as keys or otherwise must not be aliased.
+                if (source is AsyncIterator<TSource>)
                 {
-                    var e = source.GetEnumerator();
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+                    return source;
+                }
 
-                    var q = new Queue<TSource>();
-                    var current = default(TSource);
+                count = 0;
+            }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                var item = e.Current;
-
-                                q.Enqueue(item);
-                                if (q.Count > count)
-                                {
-                                    current = q.Dequeue();
-                                    return true;
-                                }
-                                return await f(ct)
-                                           .ConfigureAwait(false);
-                            }
-                            return false;
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new SkipLastAsyncIterator<TSource>(source, count);
         }
 
         public static IAsyncEnumerable<TSource> SkipWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
@@ -109,88 +60,308 @@ namespace System.Linq
             if (predicate == null)
                 throw new ArgumentNullException(nameof(predicate));
 
-            return CreateEnumerable(
-                () =>
+            return new SkipWhileAsyncIterator<TSource>(source, predicate);
+        }
+
+        public static IAsyncEnumerable<TSource> SkipWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            return new SkipWhileWithIndexAsyncIterator<TSource>(source, predicate);
+        }
+
+        private sealed class SkipAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly int count;
+            private readonly IAsyncEnumerable<TSource> source;
+            private int currentCount;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public SkipAsyncIterator(IAsyncEnumerable<TSource> source, int count)
+            {
+                this.source = source;
+                this.count = count;
+                currentCount = count;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new SkipAsyncIterator<TSource>(source, count);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
                 {
-                    var e = source.GetEnumerator();
-                    var skipping = true;
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+                        // skip elements as requested
+                        while (currentCount > 0 && await enumerator.MoveNext(cancellationToken)
+                                                                   .ConfigureAwait(false))
+                        {
+                            currentCount--;
+                        }
+                        if (currentCount <= 0)
                         {
-                            if (skipping)
+                            state = AsyncIteratorState.Iterating;
+                            goto case AsyncIteratorState.Iterating;
+                        }
+                        break;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
+        }
+
+        private sealed class SkipLastAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly int count;
+            private readonly IAsyncEnumerable<TSource> source;
+            private IAsyncEnumerator<TSource> enumerator;
+            private Queue<TSource> queue;
+
+            public SkipLastAsyncIterator(IAsyncEnumerable<TSource> source, int count)
+            {
+                this.source = source;
+                this.count = count;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new SkipLastAsyncIterator<TSource>(source, count);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+                queue = null; // release the memory
+
+                base.Dispose();
+            }
+
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        queue = new Queue<TSource>();
+
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            queue.Enqueue(item);
+                            if (queue.Count > count)
                             {
-                                if (await e.MoveNext(ct)
-                                           .ConfigureAwait(false))
-                                {
-                                    if (predicate(e.Current))
-                                        return await f(ct)
-                                                   .ConfigureAwait(false);
-                                    skipping = false;
-                                    return true;
-                                }
-                                return false;
+                                current = queue.Dequeue();
+                                return true;
                             }
-                            return await e.MoveNext(ct)
-                                          .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+                            goto case AsyncIteratorState.Iterating; // loop until either the await is false or we return an item
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
         }
 
-        public static IAsyncEnumerable<TSource> SkipWhile<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
+        private sealed class SkipWhileAsyncIterator<TSource> : AsyncIterator<TSource>
         {
-            if (source == null)
-                throw new ArgumentNullException(nameof(source));
-            if (predicate == null)
-                throw new ArgumentNullException(nameof(predicate));
+            private readonly Func<TSource, bool> predicate;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private bool doMoveNext;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public SkipWhileAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
+            {
+                Debug.Assert(predicate != null);
+                Debug.Assert(source != null);
+
+                this.source = source;
+                this.predicate = predicate;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new SkipWhileAsyncIterator<TSource>(source, predicate);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+
+                        // skip elements as requested
+                        while (await enumerator.MoveNext(cancellationToken)
+                                               .ConfigureAwait(false))
+                        {
+                            var element = enumerator.Current;
+                            if (!predicate(element))
+                            {
+                                doMoveNext = false;
+                                state = AsyncIteratorState.Iterating;
+                                goto case AsyncIteratorState.Iterating;
+                            }
+                        }
+                        break;
 
-            return CreateEnumerable(
-                () =>
+                    case AsyncIteratorState.Iterating:
+                        if (doMoveNext && await enumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
+                        if (!doMoveNext)
+                        {
+                            current = enumerator.Current;
+                            doMoveNext = true;
+                            return true;
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
+        }
+
+        private sealed class SkipWhileWithIndexAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, int, bool> predicate;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private bool doMoveNext;
+            private IAsyncEnumerator<TSource> enumerator;
+            private int index;
+
+            public SkipWhileWithIndexAsyncIterator(IAsyncEnumerable<TSource> source, Func<TSource, int, bool> predicate)
+            {
+                Debug.Assert(predicate != null);
+                Debug.Assert(source != null);
+
+                this.source = source;
+                this.predicate = predicate;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new SkipWhileWithIndexAsyncIterator<TSource>(source, predicate);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
                 {
-                    var e = source.GetEnumerator();
-                    var skipping = true;
-                    var index = 0;
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        index = -1;
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+                        // skip elements as requested
+                        while (await enumerator.MoveNext(cancellationToken)
+                                               .ConfigureAwait(false))
                         {
-                            if (skipping)
+                            checked
+                            {
+                                index++;
+                            }
+
+                            var element = enumerator.Current;
+                            if (!predicate(element, index))
                             {
-                                if (await e.MoveNext(ct)
-                                           .ConfigureAwait(false))
-                                {
-                                    if (predicate(e.Current, checked(index++)))
-                                        return await f(ct)
-                                                   .ConfigureAwait(false);
-                                    skipping = false;
-                                    return true;
-                                }
-                                return false;
+                                doMoveNext = false;
+                                state = AsyncIteratorState.Iterating;
+                                goto case AsyncIteratorState.Iterating;
                             }
-                            return await e.MoveNext(ct)
-                                          .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => e.Current,
-                        d.Dispose,
-                        e
-                    );
-                });
+                        }
+                        break;
+
+                    case AsyncIteratorState.Iterating:
+                        if (doMoveNext && await enumerator.MoveNext(cancellationToken)
+                                                          .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
+                        if (!doMoveNext)
+                        {
+                            current = enumerator.Current;
+                            doMoveNext = true;
+                            return true;
+                        }
+
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
         }
     }
 }

+ 0 - 2
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -880,7 +880,6 @@ namespace Tests
         public void Skip_Null()
         {
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Skip<int>(null, 5));
-            AssertThrows<ArgumentOutOfRangeException>(() => AsyncEnumerable.Skip<int>(AsyncEnumerable.Return(42), -1));
         }
 
         [Fact]
@@ -2464,7 +2463,6 @@ namespace Tests
         public void SkipLast_Null()
         {
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.SkipLast(default(IAsyncEnumerable<int>), 5));
-            AssertThrows<ArgumentOutOfRangeException>(() => AsyncEnumerable.SkipLast(AsyncEnumerable.Return(42), -1));
         }
 
         [Fact]