Average.Generated.tt 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. <#@ template debug="false" hostspecific="false" language="C#" #>
  2. <#@ assembly name="System.Core" #>
  3. <#@ import namespace="System.Linq" #>
  4. <#@ import namespace="System.Text" #>
  5. <#@ import namespace="System.Collections.Generic" #>
  6. <#@ output extension=".cs" #>
  7. // Licensed to the .NET Foundation under one or more agreements.
  8. // The .NET Foundation licenses this file to you under the MIT License.
  9. // See the LICENSE file in the project root for more information.
  10. using System.Collections.Generic;
  11. using System.Threading;
  12. using System.Threading.Tasks;
  13. namespace System.Linq
  14. {
  15. <#
// Although .NET 10.0's System.Linq.AsyncEnumerable defines AverageAsync, it lacks some features that System.Linq.Async previously provided:
//   - overloads of AverageAsync that accept a selector function (e.g. xs.AverageAsync(x => x.Value));
//   - the AverageAwaitAsync variants, which accept an async selector function returning a ValueTask;
//   - the AverageAwaitWithCancellationAsync variants, whose async selector function also accepts a CancellationToken.
  20. // Since we are deprecating System.Linq.Async, these methods are now available there only in the runtime (lib) assemblies and not the
  21. // ref assemblies. So they are available for binary backwards compatibility, but are not visible during compilation. This is necessary
  22. // to ensure that code now gets the .NET 10 implementations of AverageAsync where available. But we want to enable code that was using
  23. // the functionality that .NET 10 did not replicate to be able to continue to work after removing the reference to System.Linq.Async.
  24. // So we've moved the relevant functionality back into System.Interactive.Async because that has always been the home of LINQ-like
  25. // features for IAsyncEnumerable<T> that aren't part of the core LINQ functionality.
  26. #>
  27. public static partial class AsyncEnumerableEx
  28. {
  29. <#
  30. var os = new[]
  31. {
  32. new { type = "int", res = "double", sum = "long" },
  33. new { type = "long", res = "double", sum = "long" },
  34. new { type = "float", res = "float", sum = "double" },
  35. new { type = "double", res = "double", sum = "double" },
  36. new { type = "decimal", res = "decimal", sum = "decimal" },
  37. new { type = "int?", res = "double?", sum = "long" },
  38. new { type = "long?", res = "double?", sum = "long" },
  39. new { type = "float?", res = "float?", sum = "double" },
  40. new { type = "double?", res = "double?", sum = "double" },
  41. new { type = "decimal?", res = "decimal?", sum = "decimal" },
  42. };
  43. foreach (var o in os)
  44. {
  45. var isNullable = o.type.EndsWith("?");
  46. var t = o.type.TrimEnd('?');
  47. string res = "";
  48. if (t == "int" || t == "long")
  49. res = "(double)sum / count";
  50. else if (t == "double" || t == "decimal")
  51. res = "sum / count";
  52. else if (t == "float")
  53. res = "(float)(sum / count)";
  54. var typeStr = o.type;
  55. if (isNullable) {
  56. typeStr = "Nullable{" + o.type.Substring(0, 1).ToUpper() + o.type.Substring(1, o.type.Length - 2) + "}";
  57. }
  58. #>
  59. /// <summary>
  60. /// Computes the average of an async-enumerable sequence of <see cref="<#=typeStr#>" /> values that are obtained by invoking a transform function on each element of the input sequence.
  61. /// </summary>
  62. /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
  63. /// <param name="source">A sequence of values to calculate the average of.</param>
  64. /// <param name="selector">A transform function to apply to each element.</param>
  65. /// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
  66. /// <returns>An async-enumerable sequence containing a single element with the average of the sequence of values, or null if the source sequence is empty or contains only values that are null.</returns>
  67. /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
  68. /// <exception cref="InvalidOperationException">(Asynchronous) The source sequence is empty.</exception>
  69. /// <remarks>The return type of this operator differs from the corresponding operator on IEnumerable in order to retain asynchronous behavior.</remarks>
  70. public static ValueTask<<#=o.res#>> AverageAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken = default)
  71. {
  72. if (source == null)
  73. throw Error.ArgumentNull(nameof(source));
  74. if (selector == null)
  75. throw Error.ArgumentNull(nameof(selector));
  76. return Core(source, selector, cancellationToken);
  77. static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken)
  78. {
  79. <#
  80. if (isNullable)
  81. {
  82. #>
  83. await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
  84. {
  85. while (await e.MoveNextAsync())
  86. {
  87. var v = selector(e.Current);
  88. if (v.HasValue)
  89. {
  90. <#=o.sum#> sum = v.GetValueOrDefault();
  91. long count = 1;
  92. checked
  93. {
  94. while (await e.MoveNextAsync())
  95. {
  96. v = selector(e.Current);
  97. if (v.HasValue)
  98. {
  99. sum += v.GetValueOrDefault();
  100. ++count;
  101. }
  102. }
  103. }
  104. return <#=res#>;
  105. }
  106. }
  107. }
  108. return null;
  109. <#
  110. }
  111. else
  112. {
  113. #>
  114. await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
  115. {
  116. if (!await e.MoveNextAsync())
  117. {
  118. throw Error.NoElements();
  119. }
  120. <#=o.sum#> sum = selector(e.Current);
  121. long count = 1;
  122. checked
  123. {
  124. while (await e.MoveNextAsync())
  125. {
  126. sum += selector(e.Current);
  127. ++count;
  128. }
  129. }
  130. return <#=res#>;
  131. }
  132. <#
  133. }
  134. #>
  135. }
  136. }
  137. /// <summary>
  138. /// Computes the average of an async-enumerable sequence of <see cref="int"/> values that are obtained by invoking an asynchronous transform function on each element of the source sequence and awaiting the result.
  139. /// </summary>
  140. /// <typeparam name="TSource">The type of elements in the source sequence.</typeparam>
  141. /// <param name="source">An async-enumerable sequence of values to compute the average of.</param>
  142. /// <param name="selector">A transform function to invoke and await on each element of the source sequence.</param>
  143. /// <param name="cancellationToken">An optional cancellation token for cancelling the sequence at any time.</param>
  144. /// <returns>A ValueTask containing the average of the sequence of values.</returns>
  145. /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is <see langword="null"/>.</exception>
  146. /// <exception cref="InvalidOperationException">The source sequence is empty.</exception>
  147. /// <remarks>The return type of this operator differs from the corresponding operator on IEnumerable in order to retain asynchronous behavior.</remarks>
  148. [GenerateAsyncOverload]
  149. private static ValueTask<<#=o.res#>> AverageAsyncCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
  150. {
  151. if (source == null)
  152. throw Error.ArgumentNull(nameof(source));
  153. if (selector == null)
  154. throw Error.ArgumentNull(nameof(selector));
  155. return Core(source, selector, cancellationToken);
  156. static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken)
  157. {
  158. <#
  159. if (isNullable)
  160. {
  161. #>
  162. await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
  163. {
  164. while (await e.MoveNextAsync())
  165. {
  166. var v = await selector(e.Current).ConfigureAwait(false);
  167. if (v.HasValue)
  168. {
  169. <#=o.sum#> sum = v.GetValueOrDefault();
  170. long count = 1;
  171. checked
  172. {
  173. while (await e.MoveNextAsync())
  174. {
  175. v = await selector(e.Current).ConfigureAwait(false);
  176. if (v.HasValue)
  177. {
  178. sum += v.GetValueOrDefault();
  179. ++count;
  180. }
  181. }
  182. }
  183. return <#=res#>;
  184. }
  185. }
  186. }
  187. return null;
  188. <#
  189. }
  190. else
  191. {
  192. #>
  193. await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
  194. {
  195. if (!await e.MoveNextAsync())
  196. {
  197. throw Error.NoElements();
  198. }
  199. <#=o.sum#> sum = await selector(e.Current).ConfigureAwait(false);
  200. long count = 1;
  201. checked
  202. {
  203. while (await e.MoveNextAsync())
  204. {
  205. sum += await selector(e.Current).ConfigureAwait(false);
  206. ++count;
  207. }
  208. }
  209. return <#=res#>;
  210. }
  211. <#
  212. }
  213. #>
  214. }
  215. }
  216. [GenerateAsyncOverload]
  217. private static ValueTask<<#=o.res#>> AverageAsyncCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
  218. {
  219. if (source == null)
  220. throw Error.ArgumentNull(nameof(source));
  221. if (selector == null)
  222. throw Error.ArgumentNull(nameof(selector));
  223. return Core(source, selector, cancellationToken);
  224. static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken)
  225. {
  226. <#
  227. if (isNullable)
  228. {
  229. #>
  230. await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
  231. {
  232. while (await e.MoveNextAsync())
  233. {
  234. var v = await selector(e.Current, cancellationToken).ConfigureAwait(false);
  235. if (v.HasValue)
  236. {
  237. <#=o.sum#> sum = v.GetValueOrDefault();
  238. long count = 1;
  239. checked
  240. {
  241. while (await e.MoveNextAsync())
  242. {
  243. v = await selector(e.Current, cancellationToken).ConfigureAwait(false);
  244. if (v.HasValue)
  245. {
  246. sum += v.GetValueOrDefault();
  247. ++count;
  248. }
  249. }
  250. }
  251. return <#=res#>;
  252. }
  253. }
  254. }
  255. return null;
  256. <#
  257. }
  258. else
  259. {
  260. #>
  261. await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
  262. {
  263. if (!await e.MoveNextAsync())
  264. {
  265. throw Error.NoElements();
  266. }
  267. <#=o.sum#> sum = await selector(e.Current, cancellationToken).ConfigureAwait(false);
  268. long count = 1;
  269. checked
  270. {
  271. while (await e.MoveNextAsync())
  272. {
  273. sum += await selector(e.Current, cancellationToken).ConfigureAwait(false);
  274. ++count;
  275. }
  276. }
  277. return <#=res#>;
  278. }
  279. <#
  280. }
  281. #>
  282. }
  283. }
  284. <#
  285. }
  286. #>
  287. }
  288. }