فهرست منبع

Async variants for Where and Select.

Bart De Smet 8 سال پیش
والد
کامیت
d5040a051a

+ 10 - 0
Ix.NET/Source/System.Interactive.Async/AsyncIterator.cs

@@ -112,11 +112,21 @@ namespace System.Linq
                 return new SelectEnumerableAsyncIterator<TSource, TResult>(this, selector);
             }
 
+            public virtual IAsyncEnumerable<TResult> Select<TResult>(Func<TSource, Task<TResult>> selector)
+            {
+                return new SelectEnumerableAsyncIteratorWithTask<TSource, TResult>(this, selector);
+            }
+
             public virtual IAsyncEnumerable<TSource> Where(Func<TSource, bool> predicate)
             {
                 return new WhereEnumerableAsyncIterator<TSource>(this, predicate);
             }
 
+            public virtual IAsyncEnumerable<TSource> Where(Func<TSource, Task<bool>> predicate)
+            {
+                return new WhereEnumerableAsyncIteratorWithTask<TSource>(this, predicate);
+            }
+
             protected abstract Task<bool> MoveNextCore();
 
             protected virtual void OnGetEnumerator()

+ 219 - 1
Ix.NET/Source/System.Interactive.Async/Select.cs

@@ -4,7 +4,6 @@
 
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Linq
@@ -41,11 +40,46 @@ namespace System.Linq
             return new SelectEnumerableWithIndexAsyncIterator<TSource, TResult>(source, selector);
         }
 
+        public static IAsyncEnumerable<TResult> Select<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TResult>> selector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (selector == null)
+                throw new ArgumentNullException(nameof(selector));
+
+            if (source is AsyncIterator<TSource> iterator)
+            {
+                return iterator.Select(selector);
+            }
+
+            if (source is IList<TSource> ilist)
+            {
+                return new SelectIListIteratorWithTask<TSource, TResult>(ilist, selector);
+            }
+
+            return new SelectEnumerableAsyncIteratorWithTask<TSource, TResult>(source, selector);
+        }
+
+        public static IAsyncEnumerable<TResult> Select<TSource, TResult>(this IAsyncEnumerable<TSource> source, Func<TSource, int, Task<TResult>> selector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (selector == null)
+                throw new ArgumentNullException(nameof(selector));
+
+            return new SelectEnumerableWithIndexAsyncIteratorWithTask<TSource, TResult>(source, selector);
+        }
+
         private static Func<TSource, TResult> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, TMiddle> selector1, Func<TMiddle, TResult> selector2)
         {
             return x => selector2(selector1(x));
         }
 
+        private static Func<TSource, Task<TResult>> CombineSelectors<TSource, TMiddle, TResult>(Func<TSource, Task<TMiddle>> selector1, Func<TMiddle, Task<TResult>> selector2)
+        {
+            return async x => await selector2(await selector1(x).ConfigureAwait(false)).ConfigureAwait(false);
+        }
+
         internal sealed class SelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
         {
             private readonly Func<TSource, TResult> selector;
@@ -229,5 +263,189 @@ namespace System.Linq
                 return false;
             }
         }
+
+        internal sealed class SelectEnumerableAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly Func<TSource, Task<TResult>> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public SelectEnumerableAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<TResult>> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectEnumerableAsyncIteratorWithTask<TSource, TResult>(source, selector);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, Task<TResult1>> selector)
+            {
+                return new SelectEnumerableAsyncIteratorWithTask<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNextAsync()
+                                            .ConfigureAwait(false))
+                        {
+                            current = await selector(enumerator.Current).ConfigureAwait(false);
+                            return true;
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
+
+        internal sealed class SelectEnumerableWithIndexAsyncIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly Func<TSource, int, Task<TResult>> selector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private IAsyncEnumerator<TSource> enumerator;
+            private int index;
+
+            public SelectEnumerableWithIndexAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, int, Task<TResult>> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectEnumerableWithIndexAsyncIteratorWithTask<TSource, TResult>(source, selector);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        index = -1;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (await enumerator.MoveNextAsync()
+                                            .ConfigureAwait(false))
+                        {
+                            checked
+                            {
+                                index++;
+                            }
+                            current = await selector(enumerator.Current, index).ConfigureAwait(false);
+                            return true;
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
+
+        internal sealed class SelectIListIteratorWithTask<TSource, TResult> : AsyncIterator<TResult>
+        {
+            private readonly Func<TSource, Task<TResult>> selector;
+            private readonly IList<TSource> source;
+            private IEnumerator<TSource> enumerator;
+
+            public SelectIListIteratorWithTask(IList<TSource> source, Func<TSource, Task<TResult>> selector)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(selector != null);
+
+                this.source = source;
+                this.selector = selector;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new SelectIListIteratorWithTask<TSource, TResult>(source, selector);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            public override IAsyncEnumerable<TResult1> Select<TResult1>(Func<TResult, Task<TResult1>> selector)
+            {
+                return new SelectIListIteratorWithTask<TSource, TResult1>(source, CombineSelectors(this.selector, selector));
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        if (enumerator.MoveNext())
+                        {
+                            current = await selector(enumerator.Current).ConfigureAwait(false);
+                            return true;
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
     }
 }

+ 161 - 2
Ix.NET/Source/System.Interactive.Async/Where.cs

@@ -4,7 +4,6 @@
 
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Threading;
 using System.Threading.Tasks;
 
 namespace System.Linq
@@ -37,12 +36,42 @@ namespace System.Linq
             return new WhereEnumerableWithIndexAsyncIterator<TSource>(source, predicate);
         }
 
-        private static Func<TSource, bool> CombinePredicates<TSource>(Func<TSource, bool> predicate1, Func<TSource, bool> predicate2)
+        public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
+
+            if (source is AsyncIterator<TSource> iterator)
+            {
+                return iterator.Where(predicate);
+            }
+
+            // TODO: Can we add array/list optimizations here, does it make sense?
+            return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
+        }
+
+        public static IAsyncEnumerable<TSource> Where<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, int, Task<bool>> predicate)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (predicate == null)
+                throw new ArgumentNullException(nameof(predicate));
 
+            return new WhereEnumerableWithIndexAsyncIteratorWithTask<TSource>(source, predicate);
+        }
+
+        private static Func<TSource, bool> CombinePredicates<TSource>(Func<TSource, bool> predicate1, Func<TSource, bool> predicate2)
         {
             return x => predicate1(x) && predicate2(x);
         }
 
+        private static Func<TSource, Task<bool>> CombinePredicates<TSource>(Func<TSource, Task<bool>> predicate1, Func<TSource, Task<bool>> predicate2)
+        {
+            return async x => await predicate1(x).ConfigureAwait(false) && await predicate2(x).ConfigureAwait(false);
+        }
+
         internal sealed class WhereEnumerableAsyncIterator<TSource> : AsyncIterator<TSource>
         {
             private readonly Func<TSource, bool> predicate;
@@ -178,6 +207,136 @@ namespace System.Linq
             }
         }
 
+        internal sealed class WhereEnumerableAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, Task<bool>> predicate;
+            private readonly IAsyncEnumerable<TSource> source;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public WhereEnumerableAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<bool>> predicate)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(predicate != null);
+
+                this.source = source;
+                this.predicate = predicate;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, predicate);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            public override IAsyncEnumerable<TSource> Where(Func<TSource, Task<bool>> predicate)
+            {
+                return new WhereEnumerableAsyncIteratorWithTask<TSource>(source, CombinePredicates(this.predicate, predicate));
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await enumerator.MoveNextAsync()
+                                               .ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            if (await predicate(item).ConfigureAwait(false))
+                            {
+                                current = item;
+                                return true;
+                            }
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
+
+        internal sealed class WhereEnumerableWithIndexAsyncIteratorWithTask<TSource> : AsyncIterator<TSource>
+        {
+            private readonly Func<TSource, int, Task<bool>> predicate;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+            private int index;
+
+            public WhereEnumerableWithIndexAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, int, Task<bool>> predicate)
+            {
+                Debug.Assert(source != null);
+                Debug.Assert(predicate != null);
+
+                this.source = source;
+                this.predicate = predicate;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new WhereEnumerableWithIndexAsyncIteratorWithTask<TSource>(source, predicate);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                }
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        index = -1;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await enumerator.MoveNextAsync()
+                                               .ConfigureAwait(false))
+                        {
+                            checked
+                            {
+                                index++;
+                            }
+                            var item = enumerator.Current;
+                            if (await predicate(item, index).ConfigureAwait(false))
+                            {
+                                current = item;
+                                return true;
+                            }
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
+
         internal sealed class WhereSelectEnumerableAsyncIterator<TSource, TResult> : AsyncIterator<TResult>
         {
             private readonly Func<TSource, bool> predicate;