<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
<#
// One entry per generated family of overloads:
//   type - element type of the sequence (or the selector's result type)
//   res  - type of the average that is returned
//   sum  - type of the accumulator used while summing
var os = new[]
{
    new { type = "int", res = "double", sum = "long" },
    new { type = "long", res = "double", sum = "long" },
    new { type = "float", res = "float", sum = "double" },
    new { type = "double", res = "double", sum = "double" },
    new { type = "decimal", res = "decimal", sum = "decimal" },
    new { type = "int?", res = "double?", sum = "long" },
    new { type = "long?", res = "double?", sum = "long" },
    new { type = "float?", res = "float?", sum = "double" },
    new { type = "double?", res = "double?", sum = "double" },
    new { type = "decimal?", res = "decimal?", sum = "decimal" },
};

foreach (var o in os)
{
    var isNullable = o.type.EndsWith("?");
    var t = o.type.TrimEnd('?');

    // Expression emitted to turn the accumulated `sum` and `count` into the result.
    string res = "";
    if (t == "int" || t == "long")
        res = "(double)sum / count";     // force floating-point division for integral sums
    else if (t == "double" || t == "decimal")
        res = "sum / count";
    else if (t == "float")
        res = "(float)(sum / count)";    // sum is accumulated as double, result narrowed to float

    // Sentence used in the generated XML docs; the two variants handle "nothing to average" differently.
    var emptyBehavior = isNullable
        ? "Produces null when no non-null value is found."
        : "Throws the exception produced by Error.NoElements() when the source sequence is empty.";
#>
        /// <summary>
        /// Computes the average of an async-enumerable sequence of <c><#=o.type#></c> values.
        /// </summary>
        /// <param name="source">Sequence of values to average.</param>
        /// <param name="cancellationToken">Token passed along when obtaining the enumerator of <paramref name="source"/>.</param>
        /// <returns>A <c>ValueTask</c> producing the average as <c><#=o.res#></c>.</returns>
        /// <remarks>
        /// The null check on <paramref name="source"/> runs eagerly, before the asynchronous core starts,
        /// and throws the exception produced by <c>Error.ArgumentNull</c>.
        /// <#=emptyBehavior#>
        /// The running sum (<c><#=o.sum#></c>) and element count are updated inside a <c>checked</c> block.
        /// </remarks>
        public static ValueTask<<#=o.res#>> AverageAsync(this IAsyncEnumerable<<#=o.type#>> source, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));

            return Core(source, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<<#=o.type#>> _source, CancellationToken _cancellationToken)
            {
<#
    if (isNullable)
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    // Skip leading nulls; the first non-null value seeds the sum.
                    while (await e.MoveNextAsync())
                    {
                        var v = e.Current;
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = e.Current;
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }
                finally
                {
                    await e.DisposeAsync();
                }

                // No non-null value was seen.
                return null;
<#
    }
    else
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = e.Current;
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += e.Current;
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
                finally
                {
                    await e.DisposeAsync();
                }
<#
    }
#>
            }
        }

        /// <summary>
        /// Computes the average of the <c><#=o.type#></c> values obtained by applying
        /// <paramref name="selector"/> to each element of an async-enumerable sequence.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements of <paramref name="source"/>.</typeparam>
        /// <param name="source">Sequence whose elements are projected and averaged.</param>
        /// <param name="selector">Synchronous projection invoked once per element.</param>
        /// <param name="cancellationToken">Token passed along when obtaining the enumerator of <paramref name="source"/>.</param>
        /// <returns>A <c>ValueTask</c> producing the average as <c><#=o.res#></c>.</returns>
        /// <remarks>
        /// The null checks on <paramref name="source"/> and <paramref name="selector"/> run eagerly, before the
        /// asynchronous core starts, and throw the exception produced by <c>Error.ArgumentNull</c>.
        /// <#=emptyBehavior#>
        /// The running sum (<c><#=o.sum#></c>) and element count are updated inside a <c>checked</c> block.
        /// </remarks>
        public static ValueTask<<#=o.res#>> AverageAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return Core(source, selector, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> _source, Func<TSource, <#=o.type#>> _selector, CancellationToken _cancellationToken)
            {
<#
    if (isNullable)
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    // Skip elements whose projection is null; the first non-null projection seeds the sum.
                    while (await e.MoveNextAsync())
                    {
                        var v = _selector(e.Current);
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = _selector(e.Current);
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }
                finally
                {
                    await e.DisposeAsync();
                }

                // No non-null projected value was seen.
                return null;
<#
    }
    else
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = _selector(e.Current);
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += _selector(e.Current);
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
                finally
                {
                    await e.DisposeAsync();
                }
<#
    }
#>
            }
        }

        /// <summary>
        /// Computes the average of the <c><#=o.type#></c> values obtained by applying the asynchronous
        /// <paramref name="selector"/> to each element of an async-enumerable sequence.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements of <paramref name="source"/>.</typeparam>
        /// <param name="source">Sequence whose elements are projected and averaged.</param>
        /// <param name="selector">Asynchronous projection invoked once per element; its result is awaited with <c>ConfigureAwait(false)</c> before the next element is requested.</param>
        /// <param name="cancellationToken">Token passed along when obtaining the enumerator of <paramref name="source"/>.</param>
        /// <returns>A <c>ValueTask</c> producing the average as <c><#=o.res#></c>.</returns>
        /// <remarks>
        /// The null checks on <paramref name="source"/> and <paramref name="selector"/> run eagerly, before the
        /// asynchronous core starts, and throw the exception produced by <c>Error.ArgumentNull</c>.
        /// <#=emptyBehavior#>
        /// The running sum (<c><#=o.sum#></c>) and element count are updated inside a <c>checked</c> block.
        /// </remarks>
        public static ValueTask<<#=o.res#>> AverageAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return Core(source, selector, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> _source, Func<TSource, ValueTask<<#=o.type#>>> _selector, CancellationToken _cancellationToken)
            {
<#
    if (isNullable)
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    // Skip elements whose projection is null; the first non-null projection seeds the sum.
                    while (await e.MoveNextAsync())
                    {
                        var v = await _selector(e.Current).ConfigureAwait(false);
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = await _selector(e.Current).ConfigureAwait(false);
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }
                finally
                {
                    await e.DisposeAsync();
                }

                // No non-null projected value was seen.
                return null;
<#
    }
    else
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = await _selector(e.Current).ConfigureAwait(false);
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += await _selector(e.Current).ConfigureAwait(false);
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
                finally
                {
                    await e.DisposeAsync();
                }
<#
    }
#>
            }
        }

#if !NO_DEEP_CANCELLATION
        /// <summary>
        /// Computes the average of the <c><#=o.type#></c> values obtained by applying the asynchronous,
        /// cancellable <paramref name="selector"/> to each element of an async-enumerable sequence.
        /// </summary>
        /// <typeparam name="TSource">Type of the elements of <paramref name="source"/>.</typeparam>
        /// <param name="source">Sequence whose elements are projected and averaged.</param>
        /// <param name="selector">Asynchronous projection invoked once per element with the element and <paramref name="cancellationToken"/>; its result is awaited with <c>ConfigureAwait(false)</c>.</param>
        /// <param name="cancellationToken">Token passed along when obtaining the enumerator of <paramref name="source"/> and forwarded to every <paramref name="selector"/> call.</param>
        /// <returns>A <c>ValueTask</c> producing the average as <c><#=o.res#></c>.</returns>
        /// <remarks>
        /// The null checks on <paramref name="source"/> and <paramref name="selector"/> run eagerly, before the
        /// asynchronous core starts, and throw the exception produced by <c>Error.ArgumentNull</c>.
        /// <#=emptyBehavior#>
        /// The running sum (<c><#=o.sum#></c>) and element count are updated inside a <c>checked</c> block.
        /// </remarks>
        public static ValueTask<<#=o.res#>> AverageAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));
            if (selector == null)
                throw Error.ArgumentNull(nameof(selector));

            return Core(source, selector, cancellationToken);

            static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> _source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> _selector, CancellationToken _cancellationToken)
            {
<#
    if (isNullable)
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    // Skip elements whose projection is null; the first non-null projection seeds the sum.
                    while (await e.MoveNextAsync())
                    {
                        var v = await _selector(e.Current, _cancellationToken).ConfigureAwait(false);
                        if (v.HasValue)
                        {
                            <#=o.sum#> sum = v.GetValueOrDefault();
                            long count = 1;
                            checked
                            {
                                while (await e.MoveNextAsync())
                                {
                                    v = await _selector(e.Current, _cancellationToken).ConfigureAwait(false);
                                    if (v.HasValue)
                                    {
                                        sum += v.GetValueOrDefault();
                                        ++count;
                                    }
                                }
                            }

                            return <#=res#>;
                        }
                    }
                }
                finally
                {
                    await e.DisposeAsync();
                }

                // No non-null projected value was seen.
                return null;
<#
    }
    else
    {
#>
                var e = _source.GetConfiguredAsyncEnumerator(_cancellationToken, false);

                try // REVIEW: Can use `await using` if we get pattern bind (HAS_AWAIT_USING_PATTERN_BIND)
                {
                    if (!await e.MoveNextAsync())
                    {
                        throw Error.NoElements();
                    }

                    <#=o.sum#> sum = await _selector(e.Current, _cancellationToken).ConfigureAwait(false);
                    long count = 1;
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            sum += await _selector(e.Current, _cancellationToken).ConfigureAwait(false);
                            ++count;
                        }
                    }

                    return <#=res#>;
                }
                finally
                {
                    await e.DisposeAsync();
                }
<#
    }
#>
            }
        }
#endif

<#
}
#>
    }
}