Count.cs 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  namespace System.Linq
  {
      // This part of the partial class holds the Count and LongCount operators
      // for asynchronous sequences.
      public static partial class AsyncEnumerable
      {
          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are counted.</param>
          /// <param name="cancellationToken">Token forwarded to the underlying count or aggregation.</param>
          /// <returns>A task producing the element count as an <see cref="int"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
          /// <remarks>
          /// When <paramref name="source"/> is an <c>IIListProvider&lt;TSource&gt;</c>, the count is requested
          /// from it through <c>GetCountAsync</c> with <c>onlyIfCheap: false</c>. Otherwise the sequence is
          /// folded from zero, adding one per element in a <c>checked</c> context.
          /// </remarks>
          public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              var provider = source as IIListProvider<TSource>;
              if (provider == null)
              {
                  // No list provider available: count by folding over every element.
                  return source.Aggregate(0, (total, element) => checked(total + 1), cancellationToken);
              }

              return provider.GetCountAsync(onlyIfCheap: false, cancellationToken: cancellationToken);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/> that satisfy
          /// <paramref name="predicate"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are tested and counted.</param>
          /// <param name="predicate">Condition an element must satisfy to be counted.</param>
          /// <param name="cancellationToken">Token forwarded to the aggregation.</param>
          /// <returns>A task producing the number of matching elements as an <see cref="int"/>.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
          /// </exception>
          /// <remarks>
          /// The sequence is filtered with <c>Where</c> and the result is folded from zero, adding one
          /// per element in a <c>checked</c> context.
          /// </remarks>
          public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              if (predicate == null)
              {
                  throw new ArgumentNullException(nameof(predicate));
              }

              var matching = source.Where(predicate);
              return matching.Aggregate(0, (total, element) => checked(total + 1), cancellationToken);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/> using
          /// <see cref="CancellationToken.None"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are counted.</param>
          /// <returns>A task producing the element count as an <see cref="int"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
          public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              return Count<TSource>(source, CancellationToken.None);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/> that satisfy
          /// <paramref name="predicate"/>, using <see cref="CancellationToken.None"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are tested and counted.</param>
          /// <param name="predicate">Condition an element must satisfy to be counted.</param>
          /// <returns>A task producing the number of matching elements as an <see cref="int"/>.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
          /// </exception>
          public static Task<int> Count<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              if (predicate == null)
              {
                  throw new ArgumentNullException(nameof(predicate));
              }

              return Count<TSource>(source, predicate, CancellationToken.None);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/>, producing the result as
          /// a <see cref="long"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are counted.</param>
          /// <param name="cancellationToken">Token forwarded to the aggregation.</param>
          /// <returns>A task producing the element count as a <see cref="long"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
          /// <remarks>
          /// The sequence is folded from <c>0L</c>, adding one per element in a <c>checked</c> context.
          /// </remarks>
          public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              return source.Aggregate(0L, (total, element) => checked(total + 1), cancellationToken);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/> that satisfy
          /// <paramref name="predicate"/>, producing the result as a <see cref="long"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are tested and counted.</param>
          /// <param name="predicate">Condition an element must satisfy to be counted.</param>
          /// <param name="cancellationToken">Token forwarded to the aggregation.</param>
          /// <returns>A task producing the number of matching elements as a <see cref="long"/>.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
          /// </exception>
          /// <remarks>
          /// The sequence is filtered with <c>Where</c> and the result is folded from <c>0L</c>, adding
          /// one per element in a <c>checked</c> context.
          /// </remarks>
          public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              if (predicate == null)
              {
                  throw new ArgumentNullException(nameof(predicate));
              }

              var matching = source.Where(predicate);
              return matching.Aggregate(0L, (total, element) => checked(total + 1), cancellationToken);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/> as a <see cref="long"/>,
          /// using <see cref="CancellationToken.None"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are counted.</param>
          /// <returns>A task producing the element count as a <see cref="long"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
          public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              return LongCount<TSource>(source, CancellationToken.None);
          }

          /// <summary>
          /// Asynchronously counts the elements of <paramref name="source"/> that satisfy
          /// <paramref name="predicate"/> as a <see cref="long"/>, using
          /// <see cref="CancellationToken.None"/>.
          /// </summary>
          /// <typeparam name="TSource">Element type of the sequence.</typeparam>
          /// <param name="source">Sequence whose elements are tested and counted.</param>
          /// <param name="predicate">Condition an element must satisfy to be counted.</param>
          /// <returns>A task producing the number of matching elements as a <see cref="long"/>.</returns>
          /// <exception cref="ArgumentNullException">
          /// <paramref name="source"/> or <paramref name="predicate"/> is <c>null</c>.
          /// </exception>
          public static Task<long> LongCount<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate)
          {
              if (source == null)
              {
                  throw new ArgumentNullException(nameof(source));
              }

              if (predicate == null)
              {
                  throw new ArgumentNullException(nameof(predicate));
              }

              return LongCount<TSource>(source, predicate, CancellationToken.None);
          }
      }
  }