Browse Source

Use await-based GroupJoin to fix deadlock - addresses #182

Oren Novotny 9 years ago
parent
commit
b589023759
1 changed files with 161 additions and 80 deletions
  1. 161 80
      Ix.NET/Source/System.Interactive.Async/AsyncEnumerable.Multiple.cs

+ 161 - 80
Ix.NET/Source/System.Interactive.Async/AsyncEnumerable.Multiple.cs

@@ -356,87 +356,8 @@ namespace System.Linq
             if (comparer == null)
                 throw new ArgumentNullException("comparer");
 
-            return Create(() =>
-            {
-                var innerMap = default(Task<ILookup<TKey, TInner>>);
-                var getInnerMap = new Func<CancellationToken, Task<ILookup<TKey, TInner>>>(ct =>
-                {
-                    if (innerMap == null)
-                        innerMap = inner.ToLookup(innerKeySelector, comparer, ct);
-
-                    return innerMap;
-                });
-
-                var outerE = outer.GetEnumerator();
-                var current = default(TResult);
-
-                var cts = new CancellationTokenDisposable();
-                var d = Disposable.Create(cts, outerE);
-
-                var f = default(Action<TaskCompletionSource<bool>, CancellationToken>);
-                f = (tcs, ct) =>
-                {
-                    getInnerMap(ct).Then(ti =>
-                    {
-                        ti.Handle(tcs, map =>
-                        {
-                            outerE.MoveNext(ct).Then(to =>
-                            {
-                                to.Handle(tcs, res =>
-                                {
-                                    if (res)
-                                    {
-                                        var element = outerE.Current;
-                                        var key = default(TKey);
-
-                                        try
-                                        {
-                                            key = outerKeySelector(element);
-                                        }
-                                        catch (Exception ex)
-                                        {
-                                            tcs.TrySetException(ex);
-                                            return;
-                                        }
-
-                                        var innerE = default(IAsyncEnumerable<TInner>);
-                                        if (!map.Contains(key))
-                                            innerE = AsyncEnumerable.Empty<TInner>();
-                                        else
-                                            innerE = map[key].ToAsyncEnumerable();
-
-                                        try
-                                        {
-                                            current = resultSelector(element, innerE);
-                                        }
-                                        catch (Exception ex)
-                                        {
-                                            tcs.TrySetException(ex);
-                                            return;
-                                        }
-
-                                        tcs.TrySetResult(true);
-                                    }
-                                    else
-                                    {
-                                        tcs.TrySetResult(false);
-                                    }
-                                });
-                            });
-                        });
-                    });
-                };
 
-                return Create(
-                    (ct, tcs) =>
-                    {
-                        f(tcs, cts.Token);
-                        return tcs.Task.UsingEnumerator(outerE);
-                    },
-                    () => current,
-                    d.Dispose
-                );
-            });
+            return new GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
         }
 
         public static IAsyncEnumerable<TResult> GroupJoin<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector)
@@ -455,6 +376,166 @@ namespace System.Linq
             return outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
         }
 
+
+        private sealed class GroupJoinAsyncEnumerable<TOuter, TInner, TKey, TResult> : IAsyncEnumerable<TResult>
+        {
+            private readonly IAsyncEnumerable<TOuter> _outer;
+            private readonly IAsyncEnumerable<TInner> _inner;
+            private readonly Func<TOuter, TKey> _outerKeySelector;
+            private readonly Func<TInner, TKey> _innerKeySelector;
+            private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
+            private readonly IEqualityComparer<TKey> _comparer;
+
+            public GroupJoinAsyncEnumerable(
+                IAsyncEnumerable<TOuter> outer,
+                IAsyncEnumerable<TInner> inner,
+                Func<TOuter, TKey> outerKeySelector,
+                Func<TInner, TKey> innerKeySelector,
+                Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
+                IEqualityComparer<TKey> comparer)
+            {
+                _outer = outer;
+                _inner = inner;
+                _outerKeySelector = outerKeySelector;
+                _innerKeySelector = innerKeySelector;
+                _resultSelector = resultSelector;
+                _comparer = comparer;
+            }
+
+            public IAsyncEnumerator<TResult> GetEnumerator()
+                => new GroupJoinAsyncEnumerator(
+                    _outer.GetEnumerator(),
+                    _inner.GetEnumerator(),
+                    _outerKeySelector,
+                    _innerKeySelector,
+                    _resultSelector,
+                    _comparer);
+
+            private sealed class GroupJoinAsyncEnumerator : IAsyncEnumerator<TResult>
+            {
+                private readonly IAsyncEnumerator<TOuter> _outer;
+                private readonly IAsyncEnumerator<TInner> _inner;
+                private readonly Func<TOuter, TKey> _outerKeySelector;
+                private readonly Func<TInner, TKey> _innerKeySelector;
+                private readonly Func<TOuter, IAsyncEnumerable<TInner>, TResult> _resultSelector;
+                private readonly IEqualityComparer<TKey> _comparer;
+
+                private Dictionary<TKey, List<TInner>> _innerGroups;
+
+                public GroupJoinAsyncEnumerator(
+                    IAsyncEnumerator<TOuter> outer,
+                    IAsyncEnumerator<TInner> inner,
+                    Func<TOuter, TKey> outerKeySelector,
+                    Func<TInner, TKey> innerKeySelector,
+                    Func<TOuter, IAsyncEnumerable<TInner>, TResult> resultSelector,
+                    IEqualityComparer<TKey> comparer)
+                {
+                    _outer = outer;
+                    _inner = inner;
+                    _outerKeySelector = outerKeySelector;
+                    _innerKeySelector = innerKeySelector;
+                    _resultSelector = resultSelector;
+                    _comparer = comparer;
+                }
+
+                public async Task<bool> MoveNext(CancellationToken cancellationToken)
+                {
+                    List<TInner> group;
+
+                    if (!await _outer.MoveNext(cancellationToken))
+                    {
+                        return false;
+                    }
+
+                    if (_innerGroups == null)
+                    {
+                        _innerGroups = new Dictionary<TKey, List<TInner>>();
+
+                        while (await _inner.MoveNext(cancellationToken))
+                        {
+                            var inner = _inner.Current;
+                            var innerKey = _innerKeySelector(inner);
+
+                            if (innerKey != null)
+                            {
+                                if (!_innerGroups.TryGetValue(innerKey, out group))
+                                {
+                                    _innerGroups.Add(innerKey, group = new List<TInner>());
+                                }
+
+                                group.Add(inner);
+                            }
+                        }
+                    }
+
+                    var outer = _outer.Current;
+                    var outerKey = _outerKeySelector(outer);
+
+                    Current
+                        = _resultSelector(
+                            outer,
+                            new AsyncEnumerableAdapter<TInner>(
+                                outerKey != null
+                                && _innerGroups.TryGetValue(outerKey, out group)
+                                    ? (IEnumerable<TInner>)group
+                                    : EmptyEnumerable<TInner>.Instance));
+
+                    return true;
+                }
+
+                public TResult Current { get; private set; }
+
+                public void Dispose()
+                {
+                    _inner.Dispose();
+                    _outer.Dispose();
+                }
+
+                private sealed class EmptyEnumerable<TElement>
+                {
+                    public static readonly TElement[] Instance = new TElement[0];
+                }
+            }
+        }
+
+        private sealed class AsyncEnumerableAdapter<T> : IAsyncEnumerable<T>
+        {
+            private readonly IEnumerable<T> _source;
+
+            public AsyncEnumerableAdapter(IEnumerable<T> source)
+            {
+                _source = source;
+            }
+
+            public IAsyncEnumerator<T> GetEnumerator()
+                => new AsyncEnumeratorAdapter(_source.GetEnumerator());
+
+            private sealed class AsyncEnumeratorAdapter : IAsyncEnumerator<T>
+            {
+                private readonly IEnumerator<T> _enumerator;
+
+                public AsyncEnumeratorAdapter(IEnumerator<T> enumerator)
+                {
+                    _enumerator = enumerator;
+                }
+
+                public Task<bool> MoveNext(CancellationToken cancellationToken)
+                {
+                    cancellationToken.ThrowIfCancellationRequested();
+
+#if HAS_AWAIT
+                    return Task.FromResult(_enumerator.MoveNext());
+#else
+                    return TaskEx.FromResult(_enumerator.MoveNext());
+#endif
+                }
+
+                public T Current => _enumerator.Current;
+
+                public void Dispose() => _enumerator.Dispose();
+            }
+        }
+
         public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
         {
             if (outer == null)