<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
<#
// One entry per generated overload family:
//   type - element type being averaged
//   res  - result type of the generated AverageAsync methods
//   sum  - accumulator type used while summing the elements
var os = new[]
{
    new { type = "int", res = "double", sum = "long" },
    new { type = "long", res = "double", sum = "long" },
    new { type = "float", res = "float", sum = "double" },
    new { type = "double", res = "double", sum = "double" },
    new { type = "decimal", res = "decimal", sum = "decimal" },
    new { type = "int?", res = "double?", sum = "long" },
    new { type = "long?", res = "double?", sum = "long" },
    new { type = "float?", res = "float?", sum = "double" },
    new { type = "double?", res = "double?", sum = "double" },
    new { type = "decimal?", res = "decimal?", sum = "decimal" },
};

foreach (var o in os)
{
    var isNullable = o.type.EndsWith("?");
    var t = o.type.TrimEnd('?');

    // Expression emitted as the final "return" of every generated method;
    // "sum" and "count" are locals declared in the generated code.
    string res = "";
    if (t == "int" || t == "long")
        res = "(double)sum / count";
    else if (t == "double" || t == "decimal")
        res = "sum / count";
    else if (t == "float")
        res = "(float)(sum / count)";

    // Spelling of the element type used inside <see cref="..."/>; nullable types become e.g. "Nullable{Int}".
    var typeStr = o.type;
    if (isNullable)
    {
        typeStr = "Nullable{" + o.type.Substring(0, 1).ToUpper() + o.type.Substring(1, o.type.Length - 2) + "}";
    }

    // The generated nullable overloads return null when no non-null value is seen;
    // the non-nullable ones throw Error.NoElements() on an empty sequence instead.
    var emptyResult = isNullable
        ? ", or null if the source sequence is empty or contains only values that are null"
        : "";

    // long accumulation overflows inside the generated checked block and decimal addition always
    // throws on overflow; double accumulation never throws.
    var canOverflow = o.sum != "double";
#>
        /// <summary>
        /// Computes the average of an async-enumerable sequence of <see cref="<#=typeStr#>" /> values.
        /// </summary>
        /// <param name="source">A sequence of <see cref="<#=typeStr#>" /> values to calculate the average of.</param>
        /// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
        /// <returns>A task whose result is the average of the sequence of values<#=emptyResult#>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
<# if (!isNullable) { #>
        /// <exception cref="InvalidOperationException">(Asynchronous) The source sequence is empty.</exception>
<# } #>
<# if (canOverflow) { #>
        /// <exception cref="OverflowException">(Asynchronous) The sum of the values exceeds the range of <see cref="<#=o.sum#>" />.</exception>
<# } #>
        public static ValueTask<<#=o.res#>> AverageAsync(this IAsyncEnumerable<<#=o.type#>> source, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));

            return Core(source, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<<#=o.type#>> source, CancellationToken cancellationToken)
            {
<# if (isNullable) { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    // Skip leading nulls; the first non-null value seeds the accumulator.
                    while (await e.MoveNextAsync())
                    {
                        var v = e.Current;
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = e.Current;
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }

                // The sequence was empty or contained only nulls.
                return null;
<# } else { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = e.Current;
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += e.Current;
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
<# } #>
            }
        }

        /// <summary>
        /// Computes the average of an async-enumerable sequence of <see cref="<#=typeStr#>" /> values that are obtained by invoking a transform function on each element of the input sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">A sequence of values to calculate the average of.</param>
        /// <param name="selector">A transform function to apply to each element.</param>
        /// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
        /// <returns>A task whose result is the average of the projected values<#=emptyResult#>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
<# if (!isNullable) { #>
        /// <exception cref="InvalidOperationException">(Asynchronous) The source sequence is empty.</exception>
<# } #>
<# if (canOverflow) { #>
        /// <exception cref="OverflowException">(Asynchronous) The sum of the projected values exceeds the range of <see cref="<#=o.sum#>" />.</exception>
<# } #>
        /// <remarks>The return type of this operator differs from the corresponding operator on IEnumerable in order to retain asynchronous behavior.</remarks>
        public static ValueTask<<#=o.res#>> AverageAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return Core(source, selector, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken)
            {
<# if (isNullable) { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    // Skip leading nulls; the first non-null projection seeds the accumulator.
                    while (await e.MoveNextAsync())
                    {
                        var v = selector(e.Current);
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = selector(e.Current);
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }

                // The sequence was empty or every projection was null.
                return null;
<# } else { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = selector(e.Current);
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += selector(e.Current);
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
<# } #>
            }
        }

        /// <summary>
        /// Computes the average of the <see cref="<#=typeStr#>" /> values produced by awaiting an asynchronous
        /// selector on each element of <paramref name="source"/>.
        /// </summary>
        internal static ValueTask<<#=o.res#>> AverageAwaitAsyncCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return Core(source, selector, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken)
            {
<# if (isNullable) { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    // Skip leading nulls; the first non-null projection seeds the accumulator.
                    while (await e.MoveNextAsync())
                    {
                        var v = await selector(e.Current).ConfigureAwait(false);
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = await selector(e.Current).ConfigureAwait(false);
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }

                // The sequence was empty or every projection was null.
                return null;
<# } else { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = await selector(e.Current).ConfigureAwait(false);
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += await selector(e.Current).ConfigureAwait(false);
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
<# } #>
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Computes the average of the <see cref="<#=typeStr#>" /> values produced by awaiting an asynchronous
        /// selector on each element of <paramref name="source"/>; the selector also receives <paramref name="cancellationToken"/>.
        /// </summary>
        internal static ValueTask<<#=o.res#>> AverageAwaitWithCancellationAsyncCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return Core(source, selector, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken)
            {
<# if (isNullable) { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    // Skip leading nulls; the first non-null projection seeds the accumulator.
                    while (await e.MoveNextAsync())
                    {
                        var v = await selector(e.Current, cancellationToken).ConfigureAwait(false);
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = await selector(e.Current, cancellationToken).ConfigureAwait(false);
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }

                // The sequence was empty or every projection was null.
                return null;
<# } else { #>
                await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = await selector(e.Current, cancellationToken).ConfigureAwait(false);
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += await selector(e.Current, cancellationToken).ConfigureAwait(false);
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
<# } #>
            }
        }
#endif

<#
}
#>
    }
}