Intersect.cs 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. namespace System.Linq
  10. {
  11. public static partial class AsyncEnumerable
  12. {
  13. public static IAsyncEnumerable<TSource> Intersect<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
  14. {
  15. if (first == null)
  16. throw new ArgumentNullException(nameof(first));
  17. if (second == null)
  18. throw new ArgumentNullException(nameof(second));
  19. if (comparer == null)
  20. throw new ArgumentNullException(nameof(comparer));
  21. return CreateEnumerable(
  22. () =>
  23. {
  24. var e = first.GetEnumerator();
  25. var cts = new CancellationTokenDisposable();
  26. var d = Disposable.Create(cts, e);
  27. var mapTask = default(Task<Dictionary<TSource, TSource>>);
  28. var getMapTask = new Func<CancellationToken, Task<Dictionary<TSource, TSource>>>(
  29. ct =>
  30. {
  31. if (mapTask == null)
  32. mapTask = second.ToDictionary(x => x, comparer, ct);
  33. return mapTask;
  34. });
  35. var f = default(Func<CancellationToken, Task<bool>>);
  36. f = async ct =>
  37. {
  38. if (await e.MoveNext(ct)
  39. .Zip(getMapTask(ct), (b, _) => b)
  40. .ConfigureAwait(false))
  41. {
  42. // Note: Result here is safe because the task
  43. // was completed in the Zip() call above
  44. if (mapTask.Result.ContainsKey(e.Current))
  45. return true;
  46. return await f(ct)
  47. .ConfigureAwait(false);
  48. }
  49. return false;
  50. };
  51. return CreateEnumerator(
  52. f,
  53. () => e.Current,
  54. d.Dispose,
  55. e
  56. );
  57. });
  58. }
  59. public static IAsyncEnumerable<TSource> Intersect<TSource>(this IAsyncEnumerable<TSource> first, IAsyncEnumerable<TSource> second)
  60. {
  61. if (first == null)
  62. throw new ArgumentNullException(nameof(first));
  63. if (second == null)
  64. throw new ArgumentNullException(nameof(second));
  65. return first.Intersect(second, EqualityComparer<TSource>.Default);
  66. }
  67. }
  68. }