Browse Source

Optimize repeat

Oren Novotny 9 years ago
parent
commit
d70e689c1a

+ 92 - 87
Ix.NET/Source/System.Interactive.Async/Repeat.cs

@@ -23,15 +23,7 @@ namespace System.Linq
 
         public static IAsyncEnumerable<TResult> Repeat<TResult>(TResult element)
         {
-            return CreateEnumerable(
-                () =>
-                {
-                    return CreateEnumerator(
-                        ct => TaskExt.True,
-                        () => element,
-                        () => { }
-                    );
-                });
+            return new RepeatElementAsyncIterator<TResult>(element);
         }
 
         public static IAsyncEnumerable<TSource> Repeat<TSource>(this IAsyncEnumerable<TSource> source, int count)
@@ -41,50 +33,7 @@ namespace System.Linq
             if (count < 0)
                 throw new ArgumentOutOfRangeException(nameof(count));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var e = default(IAsyncEnumerator<TSource>);
-                    var a = new AssignableDisposable();
-                    var n = count;
-                    var current = default(TSource);
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, a);
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (e == null)
-                            {
-                                if (n-- == 0)
-                                {
-                                    return false;
-                                }
-
-                                e = source.GetEnumerator();
-
-                                a.Disposable = e;
-                            }
-
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                current = e.Current;
-                                return true;
-                            }
-                            e = null;
-                            return await f(ct)
-                                       .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new RepeatSequenceAsyncIterator<TSource>(source, count);
         }
 
         public static IAsyncEnumerable<TSource> Repeat<TSource>(this IAsyncEnumerable<TSource> source)
@@ -92,44 +41,100 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            return CreateEnumerable(
-                () =>
+            return new RepeatSequenceAsyncIterator<TSource>(source, -1);
+        }
+
+        private sealed class RepeatElementAsyncIterator<TResult> : AsyncIterator<TResult>
+        {
+            private readonly TResult element;
+
+            public RepeatElementAsyncIterator(TResult element)
+            {
+                this.element = element;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new RepeatElementAsyncIterator<TResult>(element);
+            }
+
+            protected override Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                current = element;
+                return TaskExt.True;
+            }
+        }
+
+        private sealed class RepeatSequenceAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly int count;
+            private readonly bool isInfinite;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private int currentCount;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public RepeatSequenceAsyncIterator(IAsyncEnumerable<TSource> source, int count)
+            {
+                this.source = source;
+                this.count = count;
+                isInfinite = count < 0;
+                currentCount = count;
+            }
+
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new RepeatSequenceAsyncIterator<TSource>(source, count);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
                 {
-                    var e = default(IAsyncEnumerator<TSource>);
-                    var a = new AssignableDisposable();
-                    var current = default(TSource);
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
 
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, a);
+                base.Dispose();
+            }
 
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+
+                        if (enumerator != null)
+                        {
+                            enumerator.Dispose();
+                            enumerator = null;
+                        }
+
+                        if (!isInfinite && currentCount-- == 0)
+                            break;
+
+                        enumerator = source.GetEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
                         {
-                            if (e == null)
-                            {
-                                e = source.GetEnumerator();
-
-                                a.Disposable = e;
-                            }
-
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                current = e.Current;
-                                return true;
-                            }
-                            e = null;
-                            return await f(ct)
-                                       .ConfigureAwait(false);
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+                            current = enumerator.Current;
+                            return true;
+                        }
+
+                        goto case AsyncIteratorState.Allocated;
+                        
+                }
+
+
+                Dispose();
+
+                return false;
+            }
         }
     }
 }

+ 11 - 0
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -2024,6 +2024,17 @@ namespace Tests
             Assert.Equal(3, i);
         }
 
+        [Fact]
+        public void RepeatSeq0()
+        {
+            var i = 0;
+            var xs = RepeatXs(() => i++).ToAsyncEnumerable().Repeat(0);
+
+            var e = xs.GetEnumerator();
+
+            NoNext(e);
+        }
+
         static IEnumerable<int> RepeatXs(Action started)
         {
             started();