Memoize.cs 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the MIT License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections;
  5. using System.Collections.Generic;
  6. namespace System.Linq
  7. {
  8. public static partial class EnumerableEx
  9. {
  /// <summary>
  /// Wraps the source sequence in a shared buffer so that every enumerator obtained from the buffer can observe all
  /// of the sequence's elements while the source itself is enumerated only once.
  /// </summary>
  /// <typeparam name="TSource">Source sequence element type.</typeparam>
  /// <param name="source">Source sequence.</param>
  /// <returns>
  /// Buffer through which each enumerator can read every element of the shared source sequence without repeating
  /// the side-effects of enumerating the source.
  /// </returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
  /// <example>
  /// var rng = Enumerable.Range(0, 10).Do(x => Console.WriteLine(x)).Memoize();
  /// var e1 = rng.GetEnumerator();
  /// Assert.IsTrue(e1.MoveNext()); // Prints 0
  /// Assert.AreEqual(0, e1.Current);
  /// Assert.IsTrue(e1.MoveNext()); // Prints 1
  /// Assert.AreEqual(1, e1.Current);
  /// var e2 = rng.GetEnumerator();
  /// Assert.IsTrue(e2.MoveNext()); // Doesn't print anything; the side-effect of Do
  /// Assert.AreEqual(0, e2.Current); // has already taken place during e1's iteration.
  /// Assert.IsTrue(e1.MoveNext()); // Prints 2
  /// Assert.AreEqual(2, e1.Current);
  /// </example>
  public static IBuffer<TSource> Memoize<TSource>(this IEnumerable<TSource> source)
  {
      if (source is null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      // The source enumerator is obtained eagerly, at the time Memoize is called, and handed to the buffer.
      var enumerator = source.GetEnumerator();

      return new MemoizedBuffer<TSource>(enumerator);
  }
  /// <summary>
  /// Applies a selector function to a memoized view of the source sequence, so that every enumerator the selector
  /// obtains over that view can see all of the sequence's elements while the source is enumerated only once.
  /// </summary>
  /// <typeparam name="TSource">Source sequence element type.</typeparam>
  /// <typeparam name="TResult">Result sequence element type.</typeparam>
  /// <param name="source">Source sequence.</param>
  /// <param name="selector">Selector function with memoized access to the source sequence for each enumerator.</param>
  /// <returns>Sequence resulting from applying the selector function to the memoized view over the source sequence.</returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="source"/> or <paramref name="selector"/> is <c>null</c>.
  /// </exception>
  public static IEnumerable<TResult> Memoize<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> selector)
  {
      if (source is null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (selector is null)
      {
          throw new ArgumentNullException(nameof(selector));
      }

      // The factory builds a new memoized buffer each time it is invoked and passes it to the selector.
      return Create(() =>
      {
          var memoized = source.Memoize();
          return selector(memoized).GetEnumerator();
      });
  }
  /// <summary>
  /// Wraps the source sequence in a shared buffer intended for a fixed number of enumerators, each of which can
  /// observe all of the sequence's elements while the source itself is enumerated only once.
  /// </summary>
  /// <typeparam name="TSource">Source sequence element type.</typeparam>
  /// <param name="source">Source sequence.</param>
  /// <param name="readerCount">
  /// Number of enumerators that can access the underlying buffer. Once every enumerator has
  /// obtained an element from the buffer, the element is removed from the buffer.
  /// </param>
  /// <returns>
  /// Buffer through which the specified number of enumerators can read every element of the shared source
  /// sequence without repeating the side-effects of enumerating the source.
  /// </returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
  /// <exception cref="ArgumentOutOfRangeException"><paramref name="readerCount"/> is zero or negative.</exception>
  public static IBuffer<TSource> Memoize<TSource>(this IEnumerable<TSource> source, int readerCount)
  {
      if (source is null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (readerCount < 1)
      {
          throw new ArgumentOutOfRangeException(nameof(readerCount));
      }

      // The source enumerator is obtained eagerly, at the time Memoize is called, and handed to the buffer.
      var enumerator = source.GetEnumerator();

      return new MemoizedBuffer<TSource>(enumerator, readerCount);
  }
  /// <summary>
  /// Applies a selector function to a memoized view of the source sequence that is intended for a fixed number of
  /// enumerators, each of which can see all of the sequence's elements while the source is enumerated only once.
  /// </summary>
  /// <typeparam name="TSource">Source sequence element type.</typeparam>
  /// <typeparam name="TResult">Result sequence element type.</typeparam>
  /// <param name="source">Source sequence.</param>
  /// <param name="readerCount">
  /// Number of enumerators that can access the underlying buffer. Once every enumerator has
  /// obtained an element from the buffer, the element is removed from the buffer.
  /// </param>
  /// <param name="selector">
  /// Selector function with memoized access to the source sequence for a specified number of
  /// enumerators.
  /// </param>
  /// <returns>Sequence resulting from applying the selector function to the memoized view over the source sequence.</returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="source"/> or <paramref name="selector"/> is <c>null</c>.
  /// </exception>
  /// <exception cref="ArgumentOutOfRangeException"><paramref name="readerCount"/> is zero or negative.</exception>
  public static IEnumerable<TResult> Memoize<TSource, TResult>(this IEnumerable<TSource> source, int readerCount, Func<IEnumerable<TSource>, IEnumerable<TResult>> selector)
  {
      // Arguments are validated in declaration order: source, readerCount, selector.
      if (source is null)
      {
          throw new ArgumentNullException(nameof(source));
      }

      if (readerCount < 1)
      {
          throw new ArgumentOutOfRangeException(nameof(readerCount));
      }

      if (selector is null)
      {
          throw new ArgumentNullException(nameof(selector));
      }

      // The factory builds a new memoized buffer each time it is invoked and passes it to the selector.
      return Create(() =>
      {
          var memoized = source.Memoize(readerCount);
          return selector(memoized).GetEnumerator();
      });
  }
  /// <summary>
  /// Buffer that pulls elements from a single source enumerator on demand and stores them in a ref-counted list so
  /// that several enumerators can replay the same elements.
  /// </summary>
  /// <typeparam name="T">Element type.</typeparam>
  private sealed class MemoizedBuffer<T> : IBuffer<T>
  {
      // Guards access to _source, _buffer, _stopped, _error and _disposed.
      private readonly object _gate = new();
      private readonly IRefCountList<T> _buffer;
      private readonly IEnumerator<T> _source;
      private bool _disposed;
      private Exception? _error;
      private bool _stopped;

      /// <summary>
      /// Creates a buffer backed by a <see cref="MaxRefCountList{T}"/>.
      /// </summary>
      /// <param name="source">Source enumerator to pull elements from; disposed when the buffer is disposed.</param>
      public MemoizedBuffer(IEnumerator<T> source)
          : this(source, new MaxRefCountList<T>())
      {
      }

      /// <summary>
      /// Creates a buffer backed by a <see cref="RefCountList{T}"/> constructed with the given reader count.
      /// </summary>
      /// <param name="source">Source enumerator to pull elements from; disposed when the buffer is disposed.</param>
      /// <param name="readerCount">Reader count passed to the backing list.</param>
      public MemoizedBuffer(IEnumerator<T> source, int readerCount)
          : this(source, new RefCountList<T>(readerCount))
      {
      }

      private MemoizedBuffer(IEnumerator<T> source, IRefCountList<T> buffer)
      {
          _source = source;
          _buffer = buffer;
      }

      /// <summary>
      /// Returns an enumerator that replays buffered elements and pulls further ones from the source on demand.
      /// </summary>
      /// <exception cref="ObjectDisposedException">The buffer has been disposed.</exception>
      public IEnumerator<T> GetEnumerator()
      {
          // Checked eagerly because the iterator body below does not run until the first MoveNext.
          if (_disposed)
          {
              throw new ObjectDisposedException("");
          }

          return GetEnumerator_();
      }

      /// <summary>
      /// Non-generic counterpart of <see cref="GetEnumerator"/>.
      /// </summary>
      /// <exception cref="ObjectDisposedException">The buffer has been disposed.</exception>
      IEnumerator IEnumerable.GetEnumerator()
      {
          if (_disposed)
          {
              throw new ObjectDisposedException("");
          }

          return GetEnumerator();
      }

      /// <summary>
      /// Disposes the source enumerator and clears the buffer. Calls after the first one only re-set the flag.
      /// </summary>
      public void Dispose()
      {
          lock (_gate)
          {
              if (!_disposed)
              {
                  _source.Dispose();
                  _buffer.Clear();
              }

              _disposed = true;
          }
      }

      /// <summary>
      /// Iterator behind <see cref="GetEnumerator"/>. Each reader keeps its own position <c>i</c>; when that
      /// position runs past the buffered elements the reader advances the shared source under the lock.
      /// </summary>
      private IEnumerator<T> GetEnumerator_()
      {
          var i = 0;

          try
          {
              while (true)
              {
                  T current;

                  lock (_gate)
                  {
                      // Checked under the lock so it cannot race with Dispose, which sets the flag, disposes the
                      // source and clears the buffer while holding the same lock.
                      if (_disposed)
                      {
                          throw new ObjectDisposedException("");
                      }

                      if (i >= _buffer.Count)
                      {
                          if (!_stopped)
                          {
                              try
                              {
                                  if (_source.MoveNext())
                                  {
                                      _buffer.Add(_source.Current);
                                  }
                                  else
                                  {
                                      // Remember that the source ended so later readers do not call MoveNext
                                      // again on an exhausted enumerator.
                                      _stopped = true;
                                  }
                              }
                              catch (Exception ex)
                              {
                                  _stopped = true;
                                  _error = ex;
                                  _source.Dispose();
                              }
                          }

                          if (_stopped)
                          {
                              // The stored exception instance is rethrown to every reader that reaches this point.
                              if (_error != null)
                              {
                                  throw _error;
                              }

                              break;
                          }
                      }

                      // Read through the indexer while still holding the lock, so the read cannot interleave
                      // with another reader's Add or with Clear from Dispose.
                      current = _buffer[i];
                  }

                  // Yield outside the lock so the consumer never runs while the gate is held.
                  yield return current;

                  i++;
              }
          }
          finally
          {
              // Done is called with the index following this reader's current position.
              // NOTE(review): presumably releases this reader's claim on later elements; taken under the lock
              // because it operates on the shared buffer. Confirm against IRefCountList.
              lock (_gate)
              {
                  _buffer.Done(i + 1);
              }
          }
      }
  }
  209. }
  210. }