<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace System.Linq
{
public static partial class AsyncEnumerable
{
<#
// Generation table: one entry per supported element type.
//   type - element type of the source sequence (a trailing '?' marks the nullable variant)
//   res  - result type of the generated Average* methods
//   sum  - type of the running-total accumulator used inside the generated loops
var os = new[]
{
    new { type = "int", res = "double", sum = "long" },
    new { type = "long", res = "double", sum = "long" },
    new { type = "float", res = "float", sum = "double" },
    new { type = "double", res = "double", sum = "double" },
    new { type = "decimal", res = "decimal", sum = "decimal" },
    new { type = "int?", res = "double?", sum = "long" },
    new { type = "long?", res = "double?", sum = "long" },
    new { type = "float?", res = "float?", sum = "double" },
    new { type = "double?", res = "double?", sum = "double" },
    new { type = "decimal?", res = "decimal?", sum = "decimal" },
};

// This loop stays open across the template text below; its closing brace is in the final control block.
foreach (var o in os)
{
    // Ordinal comparison: the marker is a plain '?', so culture rules must not influence the match.
    var isNullable = o.type.EndsWith("?", System.StringComparison.Ordinal);
    var t = o.type.TrimEnd('?');

    // C# expression, pasted into the generated code, that turns the locals `sum` and `count`
    // into the result type for this element type.
    string res;
    switch (t)
    {
        case "int":
        case "long":
            res = "(double)sum / count";
            break;
        case "double":
        case "decimal":
            res = "sum / count";
            break;
        case "float":
            res = "(float)(sum / count)";
            break;
        default:
            // Fail the transformation instead of silently emitting "return ;" for an unknown type.
            throw new System.InvalidOperationException("No average expression is defined for element type '" + t + "'.");
    }

    // Display form of the element type, e.g. "int?" becomes "Nullable{Int}".
    // Invariant casing keeps the generated text independent of the build machine's culture.
    // NOTE(review): typeStr is not referenced by the template text visible in this file - confirm it is still needed.
    var typeStr = o.type;
    if (isNullable)
    {
        typeStr = "Nullable{" + o.type.Substring(0, 1).ToUpperInvariant() + o.type.Substring(1, o.type.Length - 2) + "}";
    }
#>
/// <summary>
/// Asynchronously computes the arithmetic mean of an async-enumerable sequence of <c><#=o.type#></c> values.
/// </summary>
/// <param name="source">A sequence of values to calculate the average of.</param>
/// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
/// <returns>
/// A task-like value that completes with the average of the sequence. For nullable element types,
/// null values are skipped and the result is <c>null</c> when no non-null value is encountered.
/// </returns>
/// <remarks>
/// A null <paramref name="source"/> is rejected immediately, before any asynchronous work starts.
/// For non-nullable element types, an empty source sequence is reported (asynchronously) through the returned task.
/// </remarks>
public static ValueTask<<#=o.res#>> AverageAsync(this IAsyncEnumerable<<#=o.type#>> source, CancellationToken cancellationToken = default)
{
    if (source == null)
        throw Error.ArgumentNull(nameof(source));

    return Core(source, cancellationToken);

    static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<<#=o.type#>> source, CancellationToken cancellationToken)
    {
<#
if (isNullable)
{
#>
        await using var enumerator = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

        // Skip leading nulls; the first non-null value seeds the running total.
        while (await enumerator.MoveNextAsync())
        {
            var item = enumerator.Current;

            if (!item.HasValue)
            {
                continue;
            }

            // `sum` and `count` are referenced by name from the generated result expression.
            <#=o.sum#> sum = item.GetValueOrDefault();
            long count = 1;

            // Only the accumulation over the remaining elements is overflow-checked.
            checked
            {
                while (await enumerator.MoveNextAsync())
                {
                    item = enumerator.Current;

                    if (!item.HasValue)
                    {
                        continue;
                    }

                    sum += item.GetValueOrDefault();
                    ++count;
                }
            }

            return <#=res#>;
        }

        // The sequence ended without producing any non-null value.
        return null;
<#
}
else
{
#>
        await using var enumerator = source.GetConfiguredAsyncEnumerator(cancellationToken, false);

        var hasFirst = await enumerator.MoveNextAsync();

        if (!hasFirst)
        {
            throw Error.NoElements();
        }

        // `sum` and `count` are referenced by name from the generated result expression.
        <#=o.sum#> sum = enumerator.Current;
        long count = 1;

        // Only the accumulation over the remaining elements is overflow-checked.
        checked
        {
            while (await enumerator.MoveNextAsync())
            {
                sum += enumerator.Current;
                ++count;
            }
        }

        return <#=res#>;
<#
}
#>
    }
}
/// <summary>
/// Computes the average of an async-enumerable sequence of <c><#=o.type#></c> values that are obtained by invoking a transform function on each element of the input sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <param name="source">A sequence of values to calculate the average of.</param>
/// <param name="selector">A transform function to apply to each element.</param>
/// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
/// <returns>
/// A task-like value that completes with the average of the projected values. For nullable result types,
/// null projections are skipped and the result is <c>null</c> when the source sequence is empty or yields only null projections.
/// </returns>
/// <remarks>
/// A null <paramref name="source"/> or <paramref name="selector"/> is rejected immediately, before any asynchronous work starts.
/// For non-nullable result types, an empty source sequence is reported (asynchronously) through the returned task.
/// The return type of this operator differs from the corresponding operator on IEnumerable in order to retain asynchronous behavior.
/// </remarks>
public static ValueTask<<#=o.res#>> AverageAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken = default)
{
    if (source == null)
        throw Error.ArgumentNull(nameof(source));
    if (selector == null)
        throw Error.ArgumentNull(nameof(selector));

    return Core(source, selector, cancellationToken);

    static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, <#=o.type#>> selector, CancellationToken cancellationToken)
    {
<#
if (isNullable)
{
#>
        await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
        {
            // Skip elements whose projection is null; the first non-null projection seeds the running total.
            while (await e.MoveNextAsync())
            {
                var v = selector(e.Current);

                if (v.HasValue)
                {
                    // `sum` and `count` are referenced by name from the generated result expression.
                    <#=o.sum#> sum = v.GetValueOrDefault();
                    long count = 1;

                    // Only the accumulation over the remaining elements is overflow-checked.
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            v = selector(e.Current);

                            if (v.HasValue)
                            {
                                sum += v.GetValueOrDefault();
                                ++count;
                            }
                        }
                    }

                    return <#=res#>;
                }
            }
        }

        // The sequence ended without producing any non-null projection.
        return null;
<#
}
else
{
#>
        await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
        {
            if (!await e.MoveNextAsync())
            {
                throw Error.NoElements();
            }

            // `sum` and `count` are referenced by name from the generated result expression.
            <#=o.sum#> sum = selector(e.Current);
            long count = 1;

            // Only the accumulation over the remaining elements is overflow-checked.
            checked
            {
                while (await e.MoveNextAsync())
                {
                    sum += selector(e.Current);
                    ++count;
                }
            }

            return <#=res#>;
        }
<#
}
#>
    }
}
/// <summary>
/// Computes the average of the <c><#=o.type#></c> values obtained by invoking and awaiting an asynchronous transform function on each element of the input sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <param name="source">A sequence of values to calculate the average of.</param>
/// <param name="selector">An asynchronous transform function to apply to each element; each result is awaited before the next element is requested.</param>
/// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
/// <returns>
/// A task-like value that completes with the average of the projected values. For nullable result types,
/// null projections are skipped and the result is <c>null</c> when no non-null projection is encountered.
/// </returns>
/// <remarks>
/// A null <paramref name="source"/> or <paramref name="selector"/> is rejected immediately, before any asynchronous work starts.
/// For non-nullable result types, an empty source sequence is reported (asynchronously) through the returned task.
/// </remarks>
internal static ValueTask<<#=o.res#>> AverageAwaitAsyncCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
{
    if (source == null)
        throw Error.ArgumentNull(nameof(source));
    if (selector == null)
        throw Error.ArgumentNull(nameof(selector));

    return Core(source, selector, cancellationToken);

    static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken)
    {
<#
if (isNullable)
{
#>
        await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
        {
            // Skip elements whose projection is null; the first non-null projection seeds the running total.
            while (await e.MoveNextAsync())
            {
                var v = await selector(e.Current).ConfigureAwait(false);

                if (v.HasValue)
                {
                    // `sum` and `count` are referenced by name from the generated result expression.
                    <#=o.sum#> sum = v.GetValueOrDefault();
                    long count = 1;

                    // Only the accumulation over the remaining elements is overflow-checked.
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            v = await selector(e.Current).ConfigureAwait(false);

                            if (v.HasValue)
                            {
                                sum += v.GetValueOrDefault();
                                ++count;
                            }
                        }
                    }

                    return <#=res#>;
                }
            }
        }

        // The sequence ended without producing any non-null projection.
        return null;
<#
}
else
{
#>
        await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
        {
            if (!await e.MoveNextAsync())
            {
                throw Error.NoElements();
            }

            // `sum` and `count` are referenced by name from the generated result expression.
            <#=o.sum#> sum = await selector(e.Current).ConfigureAwait(false);
            long count = 1;

            // Only the accumulation over the remaining elements is overflow-checked.
            checked
            {
                while (await e.MoveNextAsync())
                {
                    sum += await selector(e.Current).ConfigureAwait(false);
                    ++count;
                }
            }

            return <#=res#>;
        }
<#
}
#>
    }
}
#if !NO_DEEP_CANCELLATION
/// <summary>
/// Computes the average of the <c><#=o.type#></c> values obtained by invoking and awaiting an asynchronous, cancellable transform function on each element of the input sequence.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <param name="source">A sequence of values to calculate the average of.</param>
/// <param name="selector">An asynchronous transform function to apply to each element; it receives <paramref name="cancellationToken"/> on every call and each result is awaited before the next element is requested.</param>
/// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time; it is also passed to <paramref name="selector"/>.</param>
/// <returns>
/// A task-like value that completes with the average of the projected values. For nullable result types,
/// null projections are skipped and the result is <c>null</c> when no non-null projection is encountered.
/// </returns>
/// <remarks>
/// A null <paramref name="source"/> or <paramref name="selector"/> is rejected immediately, before any asynchronous work starts.
/// For non-nullable result types, an empty source sequence is reported (asynchronously) through the returned task.
/// </remarks>
internal static ValueTask<<#=o.res#>> AverageAwaitWithCancellationAsyncCore<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken = default)
{
    if (source == null)
        throw Error.ArgumentNull(nameof(source));
    if (selector == null)
        throw Error.ArgumentNull(nameof(selector));

    return Core(source, selector, cancellationToken);

    static async ValueTask<<#=o.res#>> Core(IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<<#=o.type#>>> selector, CancellationToken cancellationToken)
    {
<#
if (isNullable)
{
#>
        await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
        {
            // Skip elements whose projection is null; the first non-null projection seeds the running total.
            while (await e.MoveNextAsync())
            {
                var v = await selector(e.Current, cancellationToken).ConfigureAwait(false);

                if (v.HasValue)
                {
                    // `sum` and `count` are referenced by name from the generated result expression.
                    <#=o.sum#> sum = v.GetValueOrDefault();
                    long count = 1;

                    // Only the accumulation over the remaining elements is overflow-checked.
                    checked
                    {
                        while (await e.MoveNextAsync())
                        {
                            v = await selector(e.Current, cancellationToken).ConfigureAwait(false);

                            if (v.HasValue)
                            {
                                sum += v.GetValueOrDefault();
                                ++count;
                            }
                        }
                    }

                    return <#=res#>;
                }
            }
        }

        // The sequence ended without producing any non-null projection.
        return null;
<#
}
else
{
#>
        await using (var e = source.GetConfiguredAsyncEnumerator(cancellationToken, false))
        {
            if (!await e.MoveNextAsync())
            {
                throw Error.NoElements();
            }

            // `sum` and `count` are referenced by name from the generated result expression.
            <#=o.sum#> sum = await selector(e.Current, cancellationToken).ConfigureAwait(false);
            long count = 1;

            // Only the accumulation over the remaining elements is overflow-checked.
            checked
            {
                while (await e.MoveNextAsync())
                {
                    sum += await selector(e.Current, cancellationToken).ConfigureAwait(false);
                    ++count;
                }
            }

            return <#=res#>;
        }
<#
}
#>
    }
}
#endif
<#
}
#>
}
}