Join.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Threading.Tasks;
  7. namespace System.Linq
  8. {
  9. public static partial class AsyncEnumerable
  10. {
  /// <summary>
  /// Correlates the elements of two async sequences on matching keys, comparing keys with
  /// <see cref="EqualityComparer{T}.Default"/>.
  /// </summary>
  /// <typeparam name="TOuter">Element type of the outer sequence.</typeparam>
  /// <typeparam name="TInner">Element type of the inner sequence.</typeparam>
  /// <typeparam name="TKey">Type of the join key.</typeparam>
  /// <typeparam name="TResult">Type of the produced elements.</typeparam>
  /// <param name="outer">Sequence whose elements are enumerated in order and matched against <paramref name="inner"/>.</param>
  /// <param name="inner">Sequence that is grouped by key once <paramref name="outer"/> turns out to be non-empty.</param>
  /// <param name="outerKeySelector">Extracts the join key from an outer element.</param>
  /// <param name="innerKeySelector">Extracts the join key from an inner element.</param>
  /// <param name="resultSelector">Builds a result from an outer element and one matching inner element.</param>
  /// <returns>A deferred sequence; neither source is enumerated until the result is enumerated.</returns>
  /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
  public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
      this IAsyncEnumerable<TOuter> outer,
      IAsyncEnumerable<TInner> inner,
      Func<TOuter, TKey> outerKeySelector,
      Func<TInner, TKey> innerKeySelector,
      Func<TOuter, TInner, TResult> resultSelector)
  {
      if (outer == null)
      {
          throw new ArgumentNullException(nameof(outer));
      }

      if (inner == null)
      {
          throw new ArgumentNullException(nameof(inner));
      }

      if (outerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(outerKeySelector));
      }

      if (innerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(innerKeySelector));
      }

      if (resultSelector == null)
      {
          throw new ArgumentNullException(nameof(resultSelector));
      }

      var defaultComparer = EqualityComparer<TKey>.Default;

      return new JoinAsyncIterator<TOuter, TInner, TKey, TResult>(
          outer,
          inner,
          outerKeySelector,
          innerKeySelector,
          resultSelector,
          defaultComparer);
  }
  /// <summary>
  /// Correlates the elements of two async sequences on matching keys, comparing keys with
  /// the supplied <paramref name="comparer"/>.
  /// </summary>
  /// <typeparam name="TOuter">Element type of the outer sequence.</typeparam>
  /// <typeparam name="TInner">Element type of the inner sequence.</typeparam>
  /// <typeparam name="TKey">Type of the join key.</typeparam>
  /// <typeparam name="TResult">Type of the produced elements.</typeparam>
  /// <param name="outer">Sequence whose elements are enumerated in order and matched against <paramref name="inner"/>.</param>
  /// <param name="inner">Sequence that is grouped by key once <paramref name="outer"/> turns out to be non-empty.</param>
  /// <param name="outerKeySelector">Extracts the join key from an outer element.</param>
  /// <param name="innerKeySelector">Extracts the join key from an inner element.</param>
  /// <param name="resultSelector">Builds a result from an outer element and one matching inner element.</param>
  /// <param name="comparer">Equality comparer used to match keys.</param>
  /// <returns>A deferred sequence; neither source is enumerated until the result is enumerated.</returns>
  /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
  public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
      this IAsyncEnumerable<TOuter> outer,
      IAsyncEnumerable<TInner> inner,
      Func<TOuter, TKey> outerKeySelector,
      Func<TInner, TKey> innerKeySelector,
      Func<TOuter, TInner, TResult> resultSelector,
      IEqualityComparer<TKey> comparer)
  {
      if (outer == null)
      {
          throw new ArgumentNullException(nameof(outer));
      }

      if (inner == null)
      {
          throw new ArgumentNullException(nameof(inner));
      }

      if (outerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(outerKeySelector));
      }

      if (innerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(innerKeySelector));
      }

      if (resultSelector == null)
      {
          throw new ArgumentNullException(nameof(resultSelector));
      }

      if (comparer == null)
      {
          throw new ArgumentNullException(nameof(comparer));
      }

      return new JoinAsyncIterator<TOuter, TInner, TKey, TResult>(
          outer,
          inner,
          outerKeySelector,
          innerKeySelector,
          resultSelector,
          comparer);
  }
  /// <summary>
  /// Correlates the elements of two async sequences on keys produced by asynchronous
  /// selectors, comparing keys with <see cref="EqualityComparer{T}.Default"/>.
  /// </summary>
  /// <typeparam name="TOuter">Element type of the outer sequence.</typeparam>
  /// <typeparam name="TInner">Element type of the inner sequence.</typeparam>
  /// <typeparam name="TKey">Type of the join key.</typeparam>
  /// <typeparam name="TResult">Type of the produced elements.</typeparam>
  /// <param name="outer">Sequence whose elements are enumerated in order and matched against <paramref name="inner"/>.</param>
  /// <param name="inner">Sequence that is grouped by key once <paramref name="outer"/> turns out to be non-empty.</param>
  /// <param name="outerKeySelector">Asynchronously extracts the join key from an outer element.</param>
  /// <param name="innerKeySelector">Asynchronously extracts the join key from an inner element.</param>
  /// <param name="resultSelector">Asynchronously builds a result from an outer element and one matching inner element.</param>
  /// <returns>A deferred sequence; neither source is enumerated until the result is enumerated.</returns>
  /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
  public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
      this IAsyncEnumerable<TOuter> outer,
      IAsyncEnumerable<TInner> inner,
      Func<TOuter, Task<TKey>> outerKeySelector,
      Func<TInner, Task<TKey>> innerKeySelector,
      Func<TOuter, TInner, Task<TResult>> resultSelector)
  {
      if (outer == null)
      {
          throw new ArgumentNullException(nameof(outer));
      }

      if (inner == null)
      {
          throw new ArgumentNullException(nameof(inner));
      }

      if (outerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(outerKeySelector));
      }

      if (innerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(innerKeySelector));
      }

      if (resultSelector == null)
      {
          throw new ArgumentNullException(nameof(resultSelector));
      }

      var defaultComparer = EqualityComparer<TKey>.Default;

      return new JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult>(
          outer,
          inner,
          outerKeySelector,
          innerKeySelector,
          resultSelector,
          defaultComparer);
  }
  /// <summary>
  /// Correlates the elements of two async sequences on keys produced by asynchronous
  /// selectors, comparing keys with the supplied <paramref name="comparer"/>.
  /// </summary>
  /// <typeparam name="TOuter">Element type of the outer sequence.</typeparam>
  /// <typeparam name="TInner">Element type of the inner sequence.</typeparam>
  /// <typeparam name="TKey">Type of the join key.</typeparam>
  /// <typeparam name="TResult">Type of the produced elements.</typeparam>
  /// <param name="outer">Sequence whose elements are enumerated in order and matched against <paramref name="inner"/>.</param>
  /// <param name="inner">Sequence that is grouped by key once <paramref name="outer"/> turns out to be non-empty.</param>
  /// <param name="outerKeySelector">Asynchronously extracts the join key from an outer element.</param>
  /// <param name="innerKeySelector">Asynchronously extracts the join key from an inner element.</param>
  /// <param name="resultSelector">Asynchronously builds a result from an outer element and one matching inner element.</param>
  /// <param name="comparer">Equality comparer used to match keys.</param>
  /// <returns>A deferred sequence; neither source is enumerated until the result is enumerated.</returns>
  /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
  public static IAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(
      this IAsyncEnumerable<TOuter> outer,
      IAsyncEnumerable<TInner> inner,
      Func<TOuter, Task<TKey>> outerKeySelector,
      Func<TInner, Task<TKey>> innerKeySelector,
      Func<TOuter, TInner, Task<TResult>> resultSelector,
      IEqualityComparer<TKey> comparer)
  {
      if (outer == null)
      {
          throw new ArgumentNullException(nameof(outer));
      }

      if (inner == null)
      {
          throw new ArgumentNullException(nameof(inner));
      }

      if (outerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(outerKeySelector));
      }

      if (innerKeySelector == null)
      {
          throw new ArgumentNullException(nameof(innerKeySelector));
      }

      if (resultSelector == null)
      {
          throw new ArgumentNullException(nameof(resultSelector));
      }

      if (comparer == null)
      {
          throw new ArgumentNullException(nameof(comparer));
      }

      return new JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult>(
          outer,
          inner,
          outerKeySelector,
          innerKeySelector,
          resultSelector,
          comparer);
  }
  /// <summary>
  /// Async iterator behind the synchronous-selector <c>Join</c> overloads. It enumerates
  /// <c>outer</c>, builds a key lookup over <c>inner</c> once <c>outer</c> has a first element,
  /// and yields one <c>resultSelector</c> result per matching (outer, inner) pair, in outer order.
  /// </summary>
  internal sealed class JoinAsyncIterator<TOuter, TInner, TKey, TResult> : AsyncIterator<TResult>
  {
      private readonly IAsyncEnumerable<TOuter> outer;
      private readonly IAsyncEnumerable<TInner> inner;
      private readonly Func<TOuter, TKey> outerKeySelector;
      private readonly Func<TInner, TKey> innerKeySelector;
      private readonly Func<TOuter, TInner, TResult> resultSelector;
      private readonly IEqualityComparer<TKey> comparer;

      // Enumerator over 'outer'; created on the first MoveNextCore call, released in DisposeAsync.
      private IAsyncEnumerator<TOuter> outerEnumerator;

      public JoinAsyncIterator(IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
      {
          // Public entry points validate arguments; these only guard internal callers.
          Debug.Assert(outer != null);
          Debug.Assert(inner != null);
          Debug.Assert(outerKeySelector != null);
          Debug.Assert(innerKeySelector != null);
          Debug.Assert(resultSelector != null);
          Debug.Assert(comparer != null);

          this.outer = outer;
          this.inner = inner;
          this.outerKeySelector = outerKeySelector;
          this.innerKeySelector = innerKeySelector;
          this.resultSelector = resultSelector;
          this.comparer = comparer;
      }

      /// <summary>Creates a new, not-yet-started iterator over the same sources, selectors and comparer.</summary>
      public override AsyncIterator<TResult> Clone()
      {
          return new JoinAsyncIterator<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
      }

      /// <summary>
      /// Disposes the outer enumerator (if one was created), drops the references held by the
      /// join state machine, and always runs the base-class disposal afterwards.
      /// </summary>
      public override async Task DisposeAsync()
      {
          // Detach first so a throwing or repeated DisposeAsync never disposes the same
          // enumerator twice.
          var enumerator = outerEnumerator;
          outerEnumerator = null;

          // Release the materialized inner lookup, the current match buffer and the current
          // outer item; otherwise they stay reachable as long as this iterator object does.
          lookup = null;
          elements = null;
          item = default(TOuter);

          try
          {
              if (enumerator != null)
              {
                  await enumerator.DisposeAsync().ConfigureAwait(false);
              }
          }
          finally
          {
              // Base cleanup must run even when disposing the outer enumerator throws.
              await base.DisposeAsync().ConfigureAwait(false);
          }
      }

      // State machine vars, carried across MoveNextCore calls.
      private Internal.Lookup<TKey, TInner> lookup; // inner elements grouped by key; built once
      private int count;                            // number of valid entries in 'elements' (tracked apart from elements.Length)
      private TInner[] elements;                    // inner elements matching the current outer item's key
      private int index;                            // next position in 'elements' to yield
      private TOuter item;                          // current outer element
      private int mode;                             // which State_* step runs next

      private const int State_If = 1;     // first outer MoveNext + lookup construction
      private const int State_DoLoop = 2; // look up matches for the current outer element
      private const int State_For = 3;    // yield one result per match
      private const int State_While = 4;  // advance to the next outer element

      protected override async Task<bool> MoveNextCore()
      {
          switch (state)
          {
              case AsyncIteratorState.Allocated:
                  outerEnumerator = outer.GetAsyncEnumerator();
                  mode = State_If;
                  state = AsyncIteratorState.Iterating;
                  goto case AsyncIteratorState.Iterating;

              case AsyncIteratorState.Iterating:
                  switch (mode)
                  {
                      case State_If:
                          // 'inner' is only enumerated once 'outer' is known to be non-empty.
                          if (await outerEnumerator.MoveNextAsync().ConfigureAwait(false))
                          {
                              lookup = await Internal.Lookup<TKey, TInner>.CreateForJoinAsync(inner, innerKeySelector, comparer).ConfigureAwait(false);

                              // An empty lookup cannot match anything, so the rest of 'outer' is skipped.
                              if (lookup.Count != 0)
                              {
                                  mode = State_DoLoop;
                                  goto case State_DoLoop;
                              }
                          }

                          break;

                      case State_DoLoop:
                          item = outerEnumerator.Current;
                          var g = lookup.GetGrouping(outerKeySelector(item), create: false);
                          if (g != null)
                          {
                              count = g._count;
                              elements = g._elements;
                              index = 0;
                              mode = State_For;
                              goto case State_For;
                          }

                          // No inner element has this key: advance to the next outer element.
                          mode = State_While;
                          goto case State_While;

                      case State_For:
                          current = resultSelector(item, elements[index]);
                          index++;
                          if (index == count)
                          {
                              // That was the last match; the next call advances 'outer'.
                              mode = State_While;
                          }

                          return true;

                      case State_While:
                          var hasNext = await outerEnumerator.MoveNextAsync().ConfigureAwait(false);
                          if (hasNext)
                          {
                              goto case State_DoLoop;
                          }

                          break;
                  }

                  // Every path that breaks out of the inner switch has exhausted the join.
                  await DisposeAsync().ConfigureAwait(false);
                  break;
          }

          return false;
      }
  }
  /// <summary>
  /// Async iterator behind the Task-returning-selector <c>Join</c> overloads. It enumerates
  /// <c>outer</c>, builds a key lookup over <c>inner</c> once <c>outer</c> has a first element,
  /// and yields one awaited <c>resultSelector</c> result per matching (outer, inner) pair,
  /// in outer order.
  /// </summary>
  internal sealed class JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult> : AsyncIterator<TResult>
  {
      private readonly IAsyncEnumerable<TOuter> outer;
      private readonly IAsyncEnumerable<TInner> inner;
      private readonly Func<TOuter, Task<TKey>> outerKeySelector;
      private readonly Func<TInner, Task<TKey>> innerKeySelector;
      private readonly Func<TOuter, TInner, Task<TResult>> resultSelector;
      private readonly IEqualityComparer<TKey> comparer;

      // Enumerator over 'outer'; created on the first MoveNextCore call, released in DisposeAsync.
      private IAsyncEnumerator<TOuter> outerEnumerator;

      public JoinAsyncIteratorWithTask(IAsyncEnumerable<TOuter> outer, IAsyncEnumerable<TInner> inner, Func<TOuter, Task<TKey>> outerKeySelector, Func<TInner, Task<TKey>> innerKeySelector, Func<TOuter, TInner, Task<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
      {
          // Public entry points validate arguments; these only guard internal callers.
          Debug.Assert(outer != null);
          Debug.Assert(inner != null);
          Debug.Assert(outerKeySelector != null);
          Debug.Assert(innerKeySelector != null);
          Debug.Assert(resultSelector != null);
          Debug.Assert(comparer != null);

          this.outer = outer;
          this.inner = inner;
          this.outerKeySelector = outerKeySelector;
          this.innerKeySelector = innerKeySelector;
          this.resultSelector = resultSelector;
          this.comparer = comparer;
      }

      /// <summary>Creates a new, not-yet-started iterator over the same sources, selectors and comparer.</summary>
      public override AsyncIterator<TResult> Clone()
      {
          return new JoinAsyncIteratorWithTask<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
      }

      /// <summary>
      /// Disposes the outer enumerator (if one was created), drops the references held by the
      /// join state machine, and always runs the base-class disposal afterwards.
      /// </summary>
      public override async Task DisposeAsync()
      {
          // Detach first so a throwing or repeated DisposeAsync never disposes the same
          // enumerator twice.
          var enumerator = outerEnumerator;
          outerEnumerator = null;

          // Release the materialized inner lookup, the current match buffer and the current
          // outer item; otherwise they stay reachable as long as this iterator object does.
          lookup = null;
          elements = null;
          item = default(TOuter);

          try
          {
              if (enumerator != null)
              {
                  await enumerator.DisposeAsync().ConfigureAwait(false);
              }
          }
          finally
          {
              // Base cleanup must run even when disposing the outer enumerator throws.
              await base.DisposeAsync().ConfigureAwait(false);
          }
      }

      // State machine vars, carried across MoveNextCore calls.
      private Internal.LookupWithTask<TKey, TInner> lookup; // inner elements grouped by key; built once
      private int count;                                    // number of valid entries in 'elements' (tracked apart from elements.Length)
      private TInner[] elements;                            // inner elements matching the current outer item's key
      private int index;                                    // next position in 'elements' to yield
      private TOuter item;                                  // current outer element
      private int mode;                                     // which State_* step runs next

      private const int State_If = 1;     // first outer MoveNext + lookup construction
      private const int State_DoLoop = 2; // look up matches for the current outer element
      private const int State_For = 3;    // yield one result per match
      private const int State_While = 4;  // advance to the next outer element

      protected override async Task<bool> MoveNextCore()
      {
          switch (state)
          {
              case AsyncIteratorState.Allocated:
                  outerEnumerator = outer.GetAsyncEnumerator();
                  mode = State_If;
                  state = AsyncIteratorState.Iterating;
                  goto case AsyncIteratorState.Iterating;

              case AsyncIteratorState.Iterating:
                  switch (mode)
                  {
                      case State_If:
                          // 'inner' is only enumerated once 'outer' is known to be non-empty.
                          if (await outerEnumerator.MoveNextAsync().ConfigureAwait(false))
                          {
                              lookup = await Internal.LookupWithTask<TKey, TInner>.CreateForJoinAsync(inner, innerKeySelector, comparer).ConfigureAwait(false);

                              // An empty lookup cannot match anything, so the rest of 'outer' is skipped.
                              if (lookup.Count != 0)
                              {
                                  mode = State_DoLoop;
                                  goto case State_DoLoop;
                              }
                          }

                          break;

                      case State_DoLoop:
                          item = outerEnumerator.Current;
                          var g = lookup.GetGrouping(await outerKeySelector(item).ConfigureAwait(false), create: false);
                          if (g != null)
                          {
                              count = g._count;
                              elements = g._elements;
                              index = 0;
                              mode = State_For;
                              goto case State_For;
                          }

                          // No inner element has this key: advance to the next outer element.
                          mode = State_While;
                          goto case State_While;

                      case State_For:
                          current = await resultSelector(item, elements[index]).ConfigureAwait(false);
                          index++;
                          if (index == count)
                          {
                              // That was the last match; the next call advances 'outer'.
                              mode = State_While;
                          }

                          return true;

                      case State_While:
                          var hasNext = await outerEnumerator.MoveNextAsync().ConfigureAwait(false);
                          if (hasNext)
                          {
                              goto case State_DoLoop;
                          }

                          break;
                  }

                  // Every path that breaks out of the inner switch has exhausted the join.
                  await DisposeAsync().ConfigureAwait(false);
                  break;
          }

          return false;
      }
  }
  285. }
  286. }