浏览代码

Async variants of Create and Defer.

Bart De Smet 8 年之前
父节点
当前提交
f92678e70b
共有 2 个文件被更改,包括 106 次插入0 次删除
  1. 97 0
      Ix.NET/Source/System.Interactive.Async/Create.cs
  2. 9 0
      Ix.NET/Source/System.Interactive.Async/Defer.cs

+ 97 - 0
Ix.NET/Source/System.Interactive.Async/Create.cs

@@ -4,6 +4,7 @@
 
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Linq
@@ -18,6 +19,14 @@ namespace System.Linq
             return new AnonymousAsyncEnumerable<T>(getEnumerator);
         }
 
+        public static IAsyncEnumerable<T> CreateEnumerable<T>(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
+        {
+            if (getEnumerator == null)
+                throw new ArgumentNullException(nameof(getEnumerator));
+
+            return new AnonymousAsyncEnumerableWithTask<T>(getEnumerator);
+        }
+
         public static IAsyncEnumerator<T> CreateEnumerator<T>(Func<Task<bool>> moveNext, Func<T> current, Func<Task> dispose)
         {
             if (moveNext == null)
@@ -60,6 +69,94 @@ namespace System.Linq
             public IAsyncEnumerator<T> GetAsyncEnumerator() => getEnumerator();
         }
 
+        private sealed class AnonymousAsyncEnumerableWithTask<T> : IAsyncEnumerable<T>
+        {
+            private readonly Func<Task<IAsyncEnumerator<T>>> getEnumerator;
+
+            public AnonymousAsyncEnumerableWithTask(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
+            {
+                Debug.Assert(getEnumerator != null);
+
+                this.getEnumerator = getEnumerator;
+            }
+
+            public IAsyncEnumerator<T> GetAsyncEnumerator() => new Enumerator(getEnumerator);
+
+            private sealed class Enumerator : IAsyncEnumerator<T>
+            {
+                private Func<Task<IAsyncEnumerator<T>>> getEnumerator;
+                private IAsyncEnumerator<T> enumerator;
+
+                public Enumerator(Func<Task<IAsyncEnumerator<T>>> getEnumerator)
+                {
+                    Debug.Assert(getEnumerator != null);
+
+                    this.getEnumerator = getEnumerator;
+                }
+
+                public T Current
+                {
+                    get
+                    {
+                        if (enumerator == null)
+                            throw new InvalidOperationException();
+
+                        return enumerator.Current;
+                    }
+                }
+
+                public async Task DisposeAsync()
+                {
+                    var old = Interlocked.Exchange(ref enumerator, DisposedEnumerator.Instance);
+
+                    if (enumerator != null)
+                    {
+                        await enumerator.DisposeAsync().ConfigureAwait(false);
+                    }
+                }
+
+                public Task<bool> MoveNextAsync()
+                {
+                    if (enumerator == null)
+                    {
+                        return InitAndMoveNextAsync();
+                    }
+
+                    return enumerator.MoveNextAsync();
+                }
+
+                private async Task<bool> InitAndMoveNextAsync()
+                {
+                    try
+                    {
+                        enumerator = await getEnumerator().ConfigureAwait(false);
+                    }
+                    catch (Exception ex)
+                    {
+                        enumerator = Throw<T>(ex).GetAsyncEnumerator();
+                        throw;
+                    }
+                    finally
+                    {
+                        getEnumerator = null;
+                    }
+
+                    return await enumerator.MoveNextAsync().ConfigureAwait(false);
+                }
+
+                private sealed class DisposedEnumerator : IAsyncEnumerator<T>
+                {
+                    public static readonly DisposedEnumerator Instance = new DisposedEnumerator();
+
+                    public T Current => throw new ObjectDisposedException("this");
+
+                    public Task DisposeAsync() => TaskExt.CompletedTask;
+
+                    public Task<bool> MoveNextAsync() => throw new ObjectDisposedException("this");
+                }
+            }
+        }
+
         private sealed class AnonymousAsyncIterator<T> : AsyncIterator<T>
         {
             private readonly Func<T> currentFunc;

+ 9 - 0
Ix.NET/Source/System.Interactive.Async/Defer.cs

@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information. 
 
 using System.Collections.Generic;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
@@ -15,5 +16,13 @@ namespace System.Linq
 
             return CreateEnumerable(() => factory().GetAsyncEnumerator());
         }
+
+        public static IAsyncEnumerable<TSource> Defer<TSource>(Func<Task<IAsyncEnumerable<TSource>>> factory)
+        {
+            if (factory == null)
+                throw new ArgumentNullException(nameof(factory));
+
+            return CreateEnumerable(async () => (await factory().ConfigureAwait(false)).GetAsyncEnumerator());
+        }
     }
 }