// Intersect.cs (3.3 KB)

  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  namespace System.Linq
  {
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Produces the set intersection of two async sequences, comparing elements with
          /// <paramref name="comparer"/>.
          /// </summary>
          /// <remarks>
          /// <para>
          /// <paramref name="second"/> is buffered into a <see cref="HashSet{T}"/> the first time the
          /// result is advanced; that buffering overlaps with the first <c>MoveNext</c> on
          /// <paramref name="first"/>.
          /// </para>
          /// <para>
          /// Each distinct element of <paramref name="first"/> that also occurs in
          /// <paramref name="second"/> is yielded once, in the order it first appears in
          /// <paramref name="first"/>. This matches the set semantics of LINQ to Objects'
          /// <c>Enumerable.Intersect</c>. Duplicate and <c>null</c> elements in either sequence are
          /// permitted.
          /// </para>
          /// </remarks>
          /// <param name="first">Sequence whose elements are yielded when they also occur in <paramref name="second"/>.</param>
          /// <param name="second">Sequence that is fully buffered and used as the membership filter.</param>
          /// <param name="comparer">Equality comparer used to match elements of the two sequences.</param>
          /// <returns>An async sequence containing the intersection of the two sequences.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="first"/>, <paramref name="second"/>, or <paramref name="comparer"/> is <c>null</c>.
          /// </exception>
          public static IAsyncEnumerable<TSource> Intersect<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
          {
              if (first == null)
              {
                  throw new ArgumentNullException(nameof(first));
              }
              if (second == null)
              {
                  throw new ArgumentNullException(nameof(second));
              }
              if (comparer == null)
              {
                  throw new ArgumentNullException(nameof(comparer));
              }

              return Create(() =>
              {
                  var e = first.GetEnumerator();

                  // NOTE(review): cts is only disposed here; its token is not visibly passed to
                  // either source in this file — confirm whether that is intended.
                  var cts = new CancellationTokenDisposable();
                  var d = Disposable.Create(cts, e);

                  // Lazily started, cached task that buffers `second` into a set. A HashSet
                  // (unlike a Dictionary keyed by element) accepts duplicate and null elements.
                  var setTask = default(Task<HashSet<TSource>>);
                  var getSetTask = new Func<CancellationToken, Task<HashSet<TSource>>>(
                      ct => setTask ?? (setTask = BuildIntersectSetAsync(second, comparer, ct)));

                  var f = new Func<CancellationToken, Task<bool>>(async ct =>
                  {
                      var moved = e.MoveNext(ct);
                      var pendingSet = getSetTask(ct);

                      // On the first call, `first` advances while `second` is still being
                      // buffered; Zip waits for both, so a fault in either source is observed.
                      var hasNext = await moved
                          .Zip(pendingSet, (b, _) => b)
                          .ConfigureAwait(false);

                      // Already completed by the Zip above; awaiting avoids a blocking .Result.
                      var set = await pendingSet.ConfigureAwait(false);

                      // Loop rather than recurse per skipped element, so long runs of
                      // non-matching items do not build an unbounded async call chain.
                      while (hasNext)
                      {
                          // Remove (not Contains) so each distinct element is yielded only once.
                          if (set.Remove(e.Current))
                          {
                              return true;
                          }

                          hasNext = await e.MoveNext(ct).ConfigureAwait(false);
                      }

                      return false;
                  });

                  return Create(
                      f,
                      () => e.Current,
                      d.Dispose,
                      e
                  );
              });
          }

          /// <summary>
          /// Produces the set intersection of two async sequences, comparing elements with
          /// <see cref="EqualityComparer{T}.Default"/>.
          /// </summary>
          /// <param name="first">Sequence whose elements are yielded when they also occur in <paramref name="second"/>.</param>
          /// <param name="second">Sequence that is fully buffered and used as the membership filter.</param>
          /// <returns>An async sequence containing the intersection of the two sequences.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="first"/> or <paramref name="second"/> is <c>null</c>.
          /// </exception>
          public static IAsyncEnumerable<TSource> Intersect<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
          {
              if (first == null)
              {
                  throw new ArgumentNullException(nameof(first));
              }
              if (second == null)
              {
                  throw new ArgumentNullException(nameof(second));
              }

              return first.Intersect(second, EqualityComparer<TSource>.Default);
          }

          /// <summary>
          /// Drains <paramref name="source"/> into a new <see cref="HashSet{T}"/> that uses
          /// <paramref name="comparer"/>. Used by <c>Intersect</c> to buffer its second sequence.
          /// </summary>
          /// <param name="source">Sequence to buffer.</param>
          /// <param name="comparer">Equality comparer for the resulting set.</param>
          /// <param name="cancellationToken">Token forwarded to each <c>MoveNext</c> call.</param>
          /// <returns>A set holding every distinct element of <paramref name="source"/>.</returns>
          private static async Task<HashSet<TSource>> BuildIntersectSetAsync<TSource>(IAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
          {
              var set = new HashSet<TSource>(comparer);

              // The enumerator is disposed whether enumeration completes or faults.
              using (var e = source.GetEnumerator())
              {
                  while (await e.MoveNext(cancellationToken).ConfigureAwait(false))
                  {
                      set.Add(e.Current);
                  }
              }

              return set;
          }
      }
  }