Memoize.cs 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. // Licensed to the .NET Foundation under one or more agreements.
  2. // The .NET Foundation licenses this file to you under the Apache 2.0 License.
  3. // See the LICENSE file in the project root for more information.
  4. using System.Collections;
  5. using System.Collections.Generic;
  6. namespace System.Linq
  7. {
  public static partial class EnumerableEx
  {
      /// <summary>
      /// Creates a buffer with a view over the source sequence, causing each enumerator to obtain access to all of the
      /// sequence's elements without causing multiple enumerations over the source.
      /// </summary>
      /// <typeparam name="TSource">Source sequence element type.</typeparam>
      /// <param name="source">Source sequence.</param>
      /// <returns>
      /// Buffer enabling each enumerator to retrieve all elements from the shared source sequence, without duplicating
      /// source enumeration side-effects.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
      /// <remarks>
      /// The source enumerator is obtained immediately, when this method is called. Disposing the returned buffer
      /// disposes that enumerator and clears the buffered elements.
      /// </remarks>
      /// <example>
      /// var rng = Enumerable.Range(0, 10).Do(x => Console.WriteLine(x)).Memoize();
      /// var e1 = rng.GetEnumerator();
      /// Assert.IsTrue(e1.MoveNext()); // Prints 0
      /// Assert.AreEqual(0, e1.Current);
      /// Assert.IsTrue(e1.MoveNext()); // Prints 1
      /// Assert.AreEqual(1, e1.Current);
      /// var e2 = rng.GetEnumerator();
      /// Assert.IsTrue(e2.MoveNext()); // Doesn't print anything; the side-effect of Do
      /// Assert.AreEqual(0, e2.Current); // has already taken place during e1's iteration.
      /// Assert.IsTrue(e1.MoveNext()); // Prints 2
      /// Assert.AreEqual(2, e1.Current);
      /// </example>
      public static IBuffer<TSource> Memoize<TSource>(this IEnumerable<TSource> source)
      {
          if (source == null)
              throw new ArgumentNullException(nameof(source));

          return new MemoizedBuffer<TSource>(source.GetEnumerator());
      }

      /// <summary>
      /// Memoizes the source sequence within a selector function where each enumerator can get access to all of the
      /// sequence's elements without causing multiple enumerations over the source.
      /// </summary>
      /// <typeparam name="TSource">Source sequence element type.</typeparam>
      /// <typeparam name="TResult">Result sequence element type.</typeparam>
      /// <param name="source">Source sequence.</param>
      /// <param name="selector">Selector function with memoized access to the source sequence for each enumerator.</param>
      /// <returns>Sequence resulting from applying the selector function to the memoized view over the source sequence.</returns>
      /// <exception cref="ArgumentNullException">
      /// <paramref name="source"/> or <paramref name="selector"/> is <c>null</c>.
      /// </exception>
      /// <remarks>
      /// Every time the enumerator factory below runs, it memoizes <paramref name="source"/> anew and applies
      /// <paramref name="selector"/> to that fresh buffer.
      /// NOTE(review): the buffer created here is never disposed explicitly by this method.
      /// </remarks>
      public static IEnumerable<TResult> Memoize<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> selector)
      {
          if (source == null)
              throw new ArgumentNullException(nameof(source));
          if (selector == null)
              throw new ArgumentNullException(nameof(selector));

          return Create(() => selector(source.Memoize()).GetEnumerator());
      }

      /// <summary>
      /// Creates a buffer with a view over the source sequence, causing a specified number of enumerators to obtain access
      /// to all of the sequence's elements without causing multiple enumerations over the source.
      /// </summary>
      /// <typeparam name="TSource">Source sequence element type.</typeparam>
      /// <param name="source">Source sequence.</param>
      /// <param name="readerCount">
      /// Number of enumerators that can access the underlying buffer. Once every enumerator has
      /// obtained an element from the buffer, the element is removed from the buffer.
      /// </param>
      /// <returns>
      /// Buffer enabling a specified number of enumerators to retrieve all elements from the shared source sequence,
      /// without duplicating source enumeration side-effects.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="readerCount"/> is zero or negative.</exception>
      public static IBuffer<TSource> Memoize<TSource>(this IEnumerable<TSource> source, int readerCount)
      {
          if (source == null)
              throw new ArgumentNullException(nameof(source));
          if (readerCount <= 0)
              throw new ArgumentOutOfRangeException(nameof(readerCount));

          return new MemoizedBuffer<TSource>(source.GetEnumerator(), readerCount);
      }

      /// <summary>
      /// Memoizes the source sequence within a selector function where a specified number of enumerators can get access to
      /// all of the sequence's elements without causing multiple enumerations over the source.
      /// </summary>
      /// <typeparam name="TSource">Source sequence element type.</typeparam>
      /// <typeparam name="TResult">Result sequence element type.</typeparam>
      /// <param name="source">Source sequence.</param>
      /// <param name="readerCount">
      /// Number of enumerators that can access the underlying buffer. Once every enumerator has
      /// obtained an element from the buffer, the element is removed from the buffer.
      /// </param>
      /// <param name="selector">
      /// Selector function with memoized access to the source sequence for a specified number of
      /// enumerators.
      /// </param>
      /// <returns>Sequence resulting from applying the selector function to the memoized view over the source sequence.</returns>
      /// <exception cref="ArgumentNullException">
      /// <paramref name="source"/> or <paramref name="selector"/> is <c>null</c>.
      /// </exception>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="readerCount"/> is zero or negative.</exception>
      public static IEnumerable<TResult> Memoize<TSource, TResult>(this IEnumerable<TSource> source, int readerCount, Func<IEnumerable<TSource>, IEnumerable<TResult>> selector)
      {
          if (source == null)
              throw new ArgumentNullException(nameof(source));
          if (readerCount <= 0)
              throw new ArgumentOutOfRangeException(nameof(readerCount));
          if (selector == null)
              throw new ArgumentNullException(nameof(selector));

          return Create(() => selector(source.Memoize(readerCount)).GetEnumerator());
      }

      /// <summary>
      /// Buffer that pulls from a single source enumerator on demand and replays the elements it has already
      /// pulled to every reader. All access to the shared source, the buffer and the state flags is serialized
      /// through a private lock, so readers may run on different threads.
      /// </summary>
      private sealed class MemoizedBuffer<T> : IBuffer<T>
      {
          // Private lock object. Locking on _source (a field that Dispose used to null out) made a second
          // Dispose, or a reader racing with Dispose, fail with ArgumentNullException.
          private readonly object _gate = new object();
          private readonly IRefCountList<T> _buffer;
          private readonly IEnumerator<T> _source;
          private bool _disposed;

          // Failure raised by the source, captured once so every reader that reaches it rethrows it
          // with its original stack trace intact.
          private System.Runtime.ExceptionServices.ExceptionDispatchInfo? _error;

          // Set once the source has faulted or reported its end; MoveNext is never called on it again.
          private bool _stopped;

          public MemoizedBuffer(IEnumerator<T> source)
              : this(source, new MaxRefCountList<T>())
          {
          }

          public MemoizedBuffer(IEnumerator<T> source, int readerCount)
              : this(source, new RefCountList<T>(readerCount))
          {
          }

          private MemoizedBuffer(IEnumerator<T> source, IRefCountList<T> buffer)
          {
              _source = source;
              _buffer = buffer;
          }

          public IEnumerator<T> GetEnumerator()
          {
              if (_disposed)
                  throw new ObjectDisposedException("");

              return GetEnumeratorCore();
          }

          IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

          /// <summary>
          /// Disposes the source enumerator and clears the buffer. Calling this more than once is a no-op.
          /// </summary>
          public void Dispose()
          {
              lock (_gate)
              {
                  if (_disposed)
                      return;

                  _source.Dispose();
                  _buffer.Clear();
                  _disposed = true;
              }
          }

          /// <summary>
          /// Iterator for one reader: replays buffered elements and, when it gets ahead of the buffer,
          /// pulls the next element from the shared source (under the lock) and appends it.
          /// </summary>
          private IEnumerator<T> GetEnumeratorCore()
          {
              // Index of the next buffered element this reader will consume.
              var i = 0;

              try
              {
                  while (true)
                  {
                      var current = default(T)!;

                      lock (_gate)
                      {
                          if (_disposed)
                              throw new ObjectDisposedException("");

                          if (i >= _buffer.Count)
                          {
                              // This reader is at the frontier; fetch a new element from the source.
                              if (!_stopped)
                              {
                                  try
                                  {
                                      if (_source.MoveNext())
                                          current = _source.Current;
                                      else
                                          _stopped = true; // Exhausted: don't poll the source again.
                                  }
                                  catch (Exception ex)
                                  {
                                      _stopped = true;
                                      _error = System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex);
                                      _source.Dispose();
                                  }
                              }

                              if (_stopped)
                              {
                                  // Replay the source's failure to every reader; otherwise the sequence has ended.
                                  _error?.Throw();
                                  break;
                              }

                              _buffer.Add(current);
                          }

                          // Read the element while still holding the lock: other readers mutate the buffer
                          // concurrently, and the indexer is invoked exactly once per element yielded.
                          current = _buffer[i];
                      }

                      yield return current;
                      i++;
                  }
              }
              finally
              {
                  lock (_gate)
                  {
                      // Tell the buffer this reader has finished. Skipped after Dispose, which has already
                      // cleared the buffer.
                      if (!_disposed)
                          _buffer.Done(i + 1);
                  }
              }
          }
      }
  }
  212. }