瀏覽代碼

Async variant of using.

Bart De Smet 8 年之前
父節點
當前提交
bcf2c05a15
共有 2 個文件被更改,包括 81 次插入4 次删除
  1. 77 0
      Ix.NET/Source/System.Interactive.Async/Using.cs
  2. 4 4
      Ix.NET/Source/Tests/AsyncTests.Creation.cs

+ 77 - 0
Ix.NET/Source/System.Interactive.Async/Using.cs

@@ -20,6 +20,16 @@ namespace System.Linq
             return new UsingAsyncIterator<TSource, TResource>(resourceFactory, enumerableFactory);
         }
 
+        public static IAsyncEnumerable<TSource> Using<TSource, TResource>(Func<Task<TResource>> resourceFactory, Func<TResource, Task<IAsyncEnumerable<TSource>>> enumerableFactory) where TResource : IDisposable
+        {
+            if (resourceFactory == null)
+                throw new ArgumentNullException(nameof(resourceFactory));
+            if (enumerableFactory == null)
+                throw new ArgumentNullException(nameof(enumerableFactory));
+
+            return new UsingAsyncIteratorWithTask<TSource, TResource>(resourceFactory, enumerableFactory);
+        }
+
         private sealed class UsingAsyncIterator<TSource, TResource> : AsyncIterator<TSource> where TResource : IDisposable
         {
             private readonly Func<TResource, IAsyncEnumerable<TSource>> enumerableFactory;
@@ -91,5 +101,72 @@ namespace System.Linq
                 base.OnGetEnumerator();
             }
         }
+
+        private sealed class UsingAsyncIteratorWithTask<TSource, TResource> : AsyncIterator<TSource> where TResource : IDisposable
+        {
+            private readonly Func<TResource, Task<IAsyncEnumerable<TSource>>> enumerableFactory;
+            private readonly Func<Task<TResource>> resourceFactory;
+
+            private IAsyncEnumerable<TSource> enumerable;
+            private IAsyncEnumerator<TSource> enumerator;
+            private TResource resource;
+
+            public UsingAsyncIteratorWithTask(Func<Task<TResource>> resourceFactory, Func<TResource, Task<IAsyncEnumerable<TSource>>> enumerableFactory)
+            {
+                Debug.Assert(resourceFactory != null);
+                Debug.Assert(enumerableFactory != null);
+
+                this.resourceFactory = resourceFactory;
+                this.enumerableFactory = enumerableFactory;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new UsingAsyncIteratorWithTask<TSource, TResource>(resourceFactory, enumerableFactory);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+
+                if (resource != null)
+                {
+                    resource.Dispose();
+                    resource = default(TResource);
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        resource = await resourceFactory().ConfigureAwait(false);
+                        enumerable = await enumerableFactory(resource).ConfigureAwait(false);
+
+                        enumerator = enumerable.GetAsyncEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
     }
 }

+ 4 - 4
Ix.NET/Source/Tests/AsyncTests.Creation.cs

@@ -5,9 +5,9 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
-using Xunit;
-using System.Threading.Tasks;
 using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
 
 namespace Tests
 {
@@ -266,8 +266,8 @@ namespace Tests
         [Fact]
         public void Using_Null()
         {
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Using<int, IDisposable>(null, _ => null));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Using<int, IDisposable>(() => new MyD(null), null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Using<int, IDisposable>(null, _ => default(IAsyncEnumerable<int>)));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Using<int, IDisposable>(() => new MyD(null), default(Func<IDisposable, IAsyncEnumerable<int>>)));
         }
 
         [Fact]