瀏覽代碼

Fixes to tests

Oren Novotny 9 年之前
父節點
當前提交
0841a4e47d

+ 28 - 7
Ix.NET/Source/System.Interactive.Async/AsyncIterator.cs

@@ -24,6 +24,7 @@ namespace System.Linq
             internal TSource current;
             private CancellationTokenSource cancellationTokenSource;
             private List<CancellationTokenRegistration> moveNextRegistrations;
+            private bool currentIsInvalid = true;
 
             protected AsyncIterator()
             {
@@ -50,21 +51,38 @@ namespace System.Linq
                     cancellationTokenSource.Cancel();
                 }
                 cancellationTokenSource.Dispose();
-                foreach (var r in moveNextRegistrations)
-                {
-                    r.Dispose();
-                }
-                moveNextRegistrations.Clear();
+
                 current = default(TSource);
                 state = State.Disposed;
+
+                var toClean = moveNextRegistrations?.ToList();
+                moveNextRegistrations = null;
+                if (toClean != null)
+                {
+                    foreach (var r in toClean)
+                    {
+                        r.Dispose();
+                    }
+                    toClean.Clear();
+                }
             }
 
-            public TSource Current => current;
+            public TSource Current
+            {
+                get
+                {
+                    if (currentIsInvalid)
+                        throw new InvalidOperationException("Enumerator is in an invalid state");
+                    return current;
+                }
+            }
 
             public async Task<bool> MoveNext(CancellationToken cancellationToken)
             {
                 if (state == State.Disposed)
+                {
                     return false;
+                }
 
                 // We keep these because cancelling any of these must trigger dispose of the iterator
                 moveNextRegistrations.Add(cancellationToken.Register(Dispose));
@@ -74,11 +92,14 @@ namespace System.Linq
                     try
                     {
                         var result = await MoveNextCore(cts.Token).ConfigureAwait(false);
-                        
+
+                        currentIsInvalid = !result; // if move next is false, invalid otherwise valid
+
                         return result;
                     }
                     catch
                     {
+                        currentIsInvalid = true;
                         Dispose();
                         throw;
                     }

+ 49 - 44
Ix.NET/Source/System.Interactive.Async/Create.cs

@@ -19,40 +19,19 @@ namespace System.Linq
 
         public static IAsyncEnumerator<T> CreateEnumerator<T>(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
         {
-            return new AnonymousAsyncEnumerator<T>(moveNext, current, dispose);
+            return new AnonymousAsyncIterator<T>(moveNext, current, dispose, null);
         }
 
         private static IAsyncEnumerator<T> CreateEnumerator<T>(Func<CancellationToken, Task<bool>> moveNext, Func<T> current,
                                                                Action dispose, IDisposable enumerator)
         {
-            return CreateEnumerator(
-                async ct =>
-                {
-                    using (ct.Register(dispose))
-                    {
-                        try
-                        {
-                            var result = await moveNext(ct)
-                                             .ConfigureAwait(false);
-                            if (!result)
-                            {
-                                enumerator?.Dispose();
-                            }
-                            return result;
-                        }
-                        catch
-                        {
-                            enumerator?.Dispose();
-                            throw;
-                        }
-                    }
-                }, current, dispose);
+            return new AnonymousAsyncIterator<T>(moveNext, current, dispose, enumerator);
         }
 
         private static IAsyncEnumerator<T> CreateEnumerator<T>(Func<CancellationToken, TaskCompletionSource<bool>, Task<bool>> moveNext, Func<T> current, Action dispose)
         {
             var self = default(IAsyncEnumerator<T>);
-            self = new AnonymousAsyncEnumerator<T>(
+            self = new AnonymousAsyncIterator<T>(
                 async ct =>
                 {
                     var tcs = new TaskCompletionSource<bool>();
@@ -71,7 +50,8 @@ namespace System.Linq
                     }
                 },
                 current,
-                dispose
+                dispose, 
+                null
             );
             return self;
         }
@@ -93,37 +73,62 @@ namespace System.Linq
             }
         }
 
-        private class AnonymousAsyncEnumerator<T> : IAsyncEnumerator<T>
+        private sealed class AnonymousAsyncIterator<T> : AsyncIterator<T>
         {
-            private readonly Func<T> _current;
-            private readonly Action _dispose;
-            private readonly Func<CancellationToken, Task<bool>> _moveNext;
-            private bool _disposed;
+            private readonly Func<T> currentFunc;
+            private readonly Action dispose;
+            private IDisposable enumerator;
+            private readonly Func<CancellationToken, Task<bool>> moveNext;
+
 
-            public AnonymousAsyncEnumerator(Func<CancellationToken, Task<bool>> moveNext, Func<T> current, Action dispose)
+            public AnonymousAsyncIterator(Func<CancellationToken, Task<bool>> moveNext, Func<T> currentFunc, Action dispose, IDisposable enumerator)
             {
-                _moveNext = moveNext;
-                _current = current;
-                _dispose = dispose;
+                this.moveNext = moveNext;
+                this.currentFunc = currentFunc;
+                this.dispose = dispose;
+                this.enumerator = enumerator;
+
+                // Explicit call to initialize enumerator mode
+                GetEnumerator();
             }
 
-            public Task<bool> MoveNext(CancellationToken cancellationToken)
+            public override AsyncIterator<T> Clone()
             {
-                if (_disposed)
-                    return TaskExt.False;
-
-                return _moveNext(cancellationToken);
+                throw new NotSupportedException("Iterator only");
             }
 
-            public T Current => _current();
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+                dispose?.Invoke();
+
+                base.Dispose();
+            }
 
-            public void Dispose()
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
             {
-                if (!_disposed)
+                switch (state)
                 {
-                    _disposed = true;
-                    _dispose();
+                    case State.Allocated:
+                        state = State.Iterating;
+                        goto case State.Iterating;
+
+                    case State.Iterating:
+                        if (await moveNext(cancellationToken).ConfigureAwait(false))
+                        {
+                            current = currentFunc();
+                            return true;
+                        }
+
+                        Dispose();
+                        break;
                 }
+
+                return false;
             }
         }
     }

+ 1 - 1
Ix.NET/Source/System.Interactive.Async/Using.cs

@@ -64,7 +64,7 @@ namespace System.Linq
                         },
                         () => current,
                         d.Dispose,
-                        d
+                        null
                     );
                 });
         }

+ 3 - 0
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -1915,6 +1915,9 @@ namespace Tests
 
             e.Dispose();
 
+           // TODO: Do the internal iterators really get cleaned up?
+           // look once this group by method has been updated
+
             HasNext(g1e, 'd');
             HasNext(g1e, 'g');
             HasNext(g1e, 'j');