Explorar el Código

Async variants of Join.

Bart De Smet hace 8 años
padre
commit
0e3f3bd005
Se han modificado 1 ficheros con 162 adiciones y 0 borrados
  1. 162 0
      Ix.NET/Source/System.Interactive.Async/Join.cs

+ 162 - 0
Ix.NET/Source/System.Interactive.Async/Join.cs

@@ -44,6 +44,40 @@ namespace System.Linq
             return new JoinAsyncIterator<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
         }
 
+        public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, Task<TKey>> outerKeySelector, Func<TInner, Task<TKey>> innerKeySelector, Func<TOuter, TInner, Task<TResult>> resultSelector)
+        {
+            if (outer == null)
+                throw new ArgumentNullException(nameof(outer));
+            if (inner == null)
+                throw new ArgumentNullException(nameof(inner));
+            if (outerKeySelector == null)
+                throw new ArgumentNullException(nameof(outerKeySelector));
+            if (innerKeySelector == null)
+                throw new ArgumentNullException(nameof(innerKeySelector));
+            if (resultSelector == null)
+                throw new ArgumentNullException(nameof(resultSelector));
+
+            return new JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
+        }
+
+        public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, Task<TKey>> outerKeySelector, Func<TInner, Task<TKey>> innerKeySelector, Func<TOuter, TInner, Task<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
+        {
+            if (outer == null)
+                throw new ArgumentNullException(nameof(outer));
+            if (inner == null)
+                throw new ArgumentNullException(nameof(inner));
+            if (outerKeySelector == null)
+                throw new ArgumentNullException(nameof(outerKeySelector));
+            if (innerKeySelector == null)
+                throw new ArgumentNullException(nameof(innerKeySelector));
+            if (resultSelector == null)
+                throw new ArgumentNullException(nameof(resultSelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return new JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
+        }
+
         internal sealed class JoinAsyncIterator<TOuter, TInner, TKey, TResult> : AsyncIterator<TResult>
         {
             private readonly IAsyncEnumerable<TOuter> outer;
@@ -171,5 +205,133 @@ namespace System.Linq
                 return false;
             }
         }
+
+        internal sealed class JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult> : AsyncIterator<TResult>
+        {
+            private readonly IAsyncEnumerable<TOuter> outer;
+            private readonly IAsyncEnumerable<TInner> inner;
+            private readonly Func<TOuter, Task<TKey>> outerKeySelector;
+            private readonly Func<TInner, Task<TKey>> innerKeySelector;
+            private readonly Func<TOuter, TInner, Task<TResult>> resultSelector;
+            private readonly IEqualityComparer<TKey> comparer;
+
+            private IAsyncEnumerator<TOuter> outerEnumerator;
+
+            public JoinAsyncIteratorWithTask(IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, Task<TKey>> outerKeySelector, Func<TInner, Task<TKey>> innerKeySelector, Func<TOuter, TInner, Task<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
+            {
+                Debug.Assert(outer != null);
+                Debug.Assert(inner != null);
+                Debug.Assert(outerKeySelector != null);
+                Debug.Assert(innerKeySelector != null);
+                Debug.Assert(resultSelector != null);
+                Debug.Assert(comparer != null);
+
+                this.outer = outer;
+                this.inner = inner;
+                this.outerKeySelector = outerKeySelector;
+                this.innerKeySelector = innerKeySelector;
+                this.resultSelector = resultSelector;
+                this.comparer = comparer;
+            }
+
+            public override AsyncIterator<TResult> Clone()
+            {
+                return new JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (outerEnumerator != null)
+                {
+                    await outerEnumerator.DisposeAsync().ConfigureAwait(false);
+                    outerEnumerator = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            // State machine vars
+            private Internal.LookupWithTask<TKey, TInner> lookup;
+            private int count;
+            private TInner[] elements;
+            private int index;
+            private TOuter item;
+            private int mode;
+
+            private const int State_If = 1;
+            private const int State_DoLoop = 2;
+            private const int State_For = 3;
+            private const int State_While = 4;
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        outerEnumerator = outer.GetAsyncEnumerator();
+                        mode = State_If;
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        switch (mode)
+                        {
+                            case State_If:
+                                if (await outerEnumerator.MoveNextAsync().ConfigureAwait(false))
+                                {
+                                    lookup = await Internal.LookupWithTask<TKey, TInner>.CreateForJoinAsync(inner, innerKeySelector, comparer).ConfigureAwait(false);
+
+                                    if (lookup.Count != 0)
+                                    {
+                                        mode = State_DoLoop;
+                                        goto case State_DoLoop;
+                                    }
+                                }
+
+                                break;
+
+                            case State_DoLoop:
+                                item = outerEnumerator.Current;
+                                var g = lookup.GetGrouping(await outerKeySelector(item).ConfigureAwait(false), create: false);
+                                if (g != null)
+                                {
+                                    count = g._count;
+                                    elements = g._elements;
+                                    index = 0;
+                                    mode = State_For;
+                                    goto case State_For;
+                                }
+
+                                // advance to while
+                                mode = State_While;
+                                goto case State_While;
+
+                            case State_For:
+                                current = await resultSelector(item, elements[index]).ConfigureAwait(false);
+                                index++;
+                                if (index == count)
+                                {
+                                    mode = State_While;
+                                }
+
+                                return true;
+
+                            case State_While:
+                                var hasNext = await outerEnumerator.MoveNextAsync().ConfigureAwait(false);
+                                if (hasNext)
+                                {
+                                    goto case State_DoLoop;
+                                }
+
+                                break;
+                        }
+
+                        await DisposeAsync().ConfigureAwait(false);
+                        break;
+                }
+
+                return false;
+            }
+        }
     }
 }