Ver Fonte

Making Using behavior consistent.

Bart De Smet há 7 anos atrás
pai
commit
f7a056b9c0

+ 29 - 8
Ix.NET/Source/System.Interactive.Async.Tests/System/Linq/Operators/Using.cs

@@ -35,9 +35,12 @@ namespace Tests
             );
 
             Assert.Equal(0, i);
+            Assert.Equal(0, d);
 
             var e = xs.GetAsyncEnumerator();
-            Assert.Equal(1, i);
+
+            Assert.Equal(0, i);
+            Assert.Equal(0, d);
         }
 
         [Fact]
@@ -56,16 +59,20 @@ namespace Tests
             );
 
             Assert.Equal(0, i);
+            Assert.Equal(0, d);
 
             var e = xs.GetAsyncEnumerator();
-            Assert.Equal(1, i);
+            Assert.Equal(0, i);
+            Assert.Equal(0, d);
 
             await e.DisposeAsync();
-            Assert.Equal(1, d);
+
+            Assert.Equal(0, i);
+            Assert.Equal(0, d);
         }
 
         [Fact]
-        public void Using3()
+        public async Task Using3()
         {
             var ex = new Exception("Bang!");
             var i = 0;
@@ -81,10 +88,17 @@ namespace Tests
             );
 
             Assert.Equal(0, i);
+            Assert.Equal(0, d);
+
+            var e = xs.GetAsyncEnumerator();
+
+            Assert.Equal(0, i);
+            Assert.Equal(0, d);
 
-            AssertThrows<Exception>(() => xs.GetAsyncEnumerator(), ex_ => ex_ == ex);
+            await e.DisposeAsync();
 
-            Assert.Equal(1, d);
+            Assert.Equal(0, i);
+            Assert.Equal(0, d);
         }
 
         [Fact]
@@ -105,9 +119,13 @@ namespace Tests
             Assert.Equal(0, i);
 
             var e = xs.GetAsyncEnumerator();
-            Assert.Equal(1, i);
+
+            Assert.Equal(0, i);
 
             HasNext(e, 42);
+
+            Assert.Equal(1, i);
+
             NoNext(e);
 
             Assert.True(disposed.Task.Result);
@@ -132,10 +150,13 @@ namespace Tests
             Assert.Equal(0, i);
 
             var e = xs.GetAsyncEnumerator();
-            Assert.Equal(1, i);
+
+            Assert.Equal(0, i);
 
             AssertThrows(() => e.MoveNextAsync().Wait(WaitTimeoutMs), SingleInnerExceptionMatches(ex));
 
+            Assert.Equal(1, i);
+
             Assert.True(disposed.Task.Result);
         }
 

+ 0 - 14
Ix.NET/Source/System.Interactive.Async/AsyncIterator.cs

@@ -32,16 +32,6 @@ namespace System.Linq
             enumerator.state = AsyncIteratorState.Allocated;
             enumerator.cancellationToken = cancellationToken;
 
-            try
-            {
-                enumerator.OnGetEnumerator(cancellationToken);
-            }
-            catch
-            {
-                enumerator.DisposeAsync(); // REVIEW: fire-and-forget?
-                throw;
-            }
-
             return enumerator;
         }
 
@@ -94,10 +84,6 @@ namespace System.Linq
         public abstract AsyncIterator<TSource> Clone();
 
         protected abstract ValueTask<bool> MoveNextCore(CancellationToken cancellationToken);
-
-        protected virtual void OnGetEnumerator(CancellationToken cancellationToken)
-        {
-        }
     }
 
     internal enum AsyncIteratorState

+ 6 - 11
Ix.NET/Source/System.Interactive.Async/System/Linq/Operators/Using.cs

@@ -73,10 +73,15 @@ namespace System.Linq
 
             protected override async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
             {
+                // NB: Earlier behavior of this operator was more eager, causing the resource factory to be called upon calling
+                //     GetAsyncEnumerator. This is inconsistent with asynchronous "using" and with a C# 8.0 async iterator with
+                //     a using statement inside, so this logic got moved to MoveNextAsync instead.
+
                 switch (state)
                 {
                     case AsyncIteratorState.Allocated:
-                        _enumerator = _enumerable.GetAsyncEnumerator(cancellationToken);
+                        _resource = _resourceFactory();
+                        _enumerator = _enumerableFactory(_resource).GetAsyncEnumerator(cancellationToken);
                         state = AsyncIteratorState.Iterating;
                         goto case AsyncIteratorState.Iterating;
 
@@ -93,16 +98,6 @@ namespace System.Linq
 
                 return false;
             }
-
-            protected override void OnGetEnumerator(CancellationToken cancellationToken)
-            {
-                // REVIEW: Wire cancellation to the functions.
-
-                _resource = _resourceFactory();
-                _enumerable = _enumerableFactory(_resource);
-
-                base.OnGetEnumerator(cancellationToken);
-            }
         }
 
         private sealed class UsingAsyncIteratorWithTask<TSource, TResource> : AsyncIterator<TSource> where TResource : IDisposable