// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
namespace System.Linq
{
public static partial class EnumerableEx
{
    /// <summary>
    /// Creates a sequence that retries enumerating the source sequence as long as an error occurs.
    /// </summary>
    /// <typeparam name="TSource">Source sequence element type.</typeparam>
    /// <param name="source">Source sequence.</param>
    /// <returns>
    /// Sequence concatenating the results of the source sequence as long as an error occurs. Elements produced
    /// by a failed attempt have already been yielded, so a later attempt yields them again.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    /// <remarks>
    /// Any exception thrown by <c>GetEnumerator</c>, <c>MoveNext</c> or <c>Current</c> of the source is
    /// swallowed and triggers a new attempt; the resulting sequence only ends when an attempt completes.
    /// Exceptions thrown while disposing a source enumerator are not caught.
    /// </remarks>
    public static IEnumerable<TSource> Retry<TSource>(this IEnumerable<TSource> source)
    {
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        // Validate eagerly; the iterator below only runs once the result is enumerated.
        return RetryInfinite(source);
    }

    /// <summary>
    /// Creates a sequence that retries enumerating the source sequence as long as an error occurs, with the
    /// specified maximum number of attempts.
    /// </summary>
    /// <typeparam name="TSource">Source sequence element type.</typeparam>
    /// <param name="source">Source sequence.</param>
    /// <param name="retryCount">
    /// Maximum number of times the source sequence is enumerated, including the first attempt. A value of 0
    /// produces an empty sequence without enumerating <paramref name="source"/>.
    /// </param>
    /// <returns>
    /// Sequence concatenating the results of the source sequence as long as an error occurs. If every attempt
    /// fails, the exception from the last attempt is rethrown with its original stack trace preserved.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="retryCount"/> is negative.</exception>
    public static IEnumerable<TSource> Retry<TSource>(this IEnumerable<TSource> source, int retryCount)
    {
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        if (retryCount < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(retryCount));
        }

        // Validate eagerly; the iterator below only runs once the result is enumerated.
        return RetryFinite(source, retryCount);
    }

    /// <summary>
    /// Iterator behind <c>Retry(source)</c>: re-enumerates <paramref name="source"/> until one attempt
    /// runs to completion, swallowing every failure.
    /// </summary>
    private static IEnumerable<TSource> RetryInfinite<TSource>(IEnumerable<TSource> source)
    {
        while (true)
        {
            IEnumerator<TSource> enumerator;
            try
            {
                // A null enumerator is treated like any other failed attempt.
                enumerator = source.GetEnumerator() ?? throw new NullReferenceException();
            }
            catch
            {
                // Deliberately swallowed: any failure simply starts another attempt.
                continue;
            }

            using (enumerator)
            {
                while (true)
                {
                    TSource current;
                    try
                    {
                        if (!enumerator.MoveNext())
                        {
                            // The attempt completed successfully; the whole sequence is done.
                            yield break;
                        }

                        current = enumerator.Current;
                    }
                    catch
                    {
                        // Deliberately swallowed: abandon this attempt (the enumerator is disposed by the
                        // using block) and let the outer loop start a new one.
                        break;
                    }

                    // Yielding is not allowed inside a try with a catch clause, so it happens here.
                    yield return current;
                }
            }
        }
    }

    /// <summary>
    /// Iterator behind <c>Retry(source, retryCount)</c>: enumerates <paramref name="source"/> at most
    /// <paramref name="retryCount"/> times, rethrowing the last failure if no attempt completes.
    /// </summary>
    private static IEnumerable<TSource> RetryFinite<TSource>(IEnumerable<TSource> source, int retryCount)
    {
        // Captured (rather than storing the bare exception) so the rethrow keeps the original stack trace.
        ExceptionDispatchInfo lastError = null;

        for (var attempt = 0; attempt < retryCount; attempt++)
        {
            IEnumerator<TSource> enumerator;
            try
            {
                // A null enumerator is treated like any other failed attempt.
                enumerator = source.GetEnumerator() ?? throw new NullReferenceException();
            }
            catch (Exception ex)
            {
                lastError = ExceptionDispatchInfo.Capture(ex);
                continue;
            }

            using (enumerator)
            {
                while (true)
                {
                    TSource current;
                    try
                    {
                        if (!enumerator.MoveNext())
                        {
                            // The attempt completed successfully; the whole sequence is done.
                            yield break;
                        }

                        current = enumerator.Current;
                    }
                    catch (Exception ex)
                    {
                        // Remember the failure, dispose this enumerator and move on to the next attempt.
                        lastError = ExceptionDispatchInfo.Capture(ex);
                        break;
                    }

                    // Yielding is not allowed inside a try with a catch clause, so it happens here.
                    yield return current;
                }
            }
        }

        // Null only when retryCount == 0 (no attempt was made): the result is then an empty sequence.
        // Otherwise every attempt failed; rethrow the last error without resetting its stack trace.
        lastError?.Throw();
    }
}
}