浏览代码

IxCS8: Allow Share to be disposed

akarnokd 7 年之前
父节点
当前提交
14a1dba258
共有 1 个文件被更改,包括 69 次插入33 次删除
  1. 69 33
      Ix.NET/Source/System.Interactive/System/Linq/Operators/Share.cs

+ 69 - 33
Ix.NET/Source/System.Interactive/System/Linq/Operators/Share.cs

@@ -2,32 +2,35 @@
 // The .NET Foundation licenses this file to you under the Apache 2.0 License.
 // See the LICENSE file in the project root for more information. 
 
+using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
 
 namespace System.Linq
 {
     public static partial class EnumerableEx
     {
         /// <summary>
-        /// Creates a buffer with a shared view over the source sequence, causing each enumerator to fetch the next element
-        /// from the source sequence.
+        ///     Creates a buffer with a shared view over the source sequence, causing each enumerator to fetch the next element
+        ///     from the source sequence.
         /// </summary>
         /// <typeparam name="TSource">Source sequence element type.</typeparam>
         /// <param name="source">Source sequence.</param>
         /// <returns>Buffer enabling each enumerator to retrieve elements from the shared source sequence.</returns>
         /// <example>
-        /// var rng = Enumerable.Range(0, 10).Share();
-        /// var e1 = rng.GetEnumerator();    // Both e1 and e2 will consume elements from
-        /// var e2 = rng.GetEnumerator();    // the source sequence.
-        /// Assert.IsTrue(e1.MoveNext());
-        /// Assert.AreEqual(0, e1.Current);
-        /// Assert.IsTrue(e1.MoveNext());
-        /// Assert.AreEqual(1, e1.Current);
-        /// Assert.IsTrue(e2.MoveNext());    // e2 "steals" element 2
-        /// Assert.AreEqual(2, e2.Current);
-        /// Assert.IsTrue(e1.MoveNext());    // e1 can't see element 2
-        /// Assert.AreEqual(3, e1.Current);
+        ///     var rng = Enumerable.Range(0, 10).Share();
+        ///     var e1 = rng.GetEnumerator();    // Both e1 and e2 will consume elements from
+        ///     var e2 = rng.GetEnumerator();    // the source sequence.
+        ///     Assert.IsTrue(e1.MoveNext());
+        ///     Assert.AreEqual(0, e1.Current);
+        ///     Assert.IsTrue(e1.MoveNext());
+        ///     Assert.AreEqual(1, e1.Current);
+        ///     Assert.IsTrue(e2.MoveNext());    // e2 "steals" element 2
+        ///     Assert.AreEqual(2, e2.Current);
+        ///     Assert.IsTrue(e1.MoveNext());    // e1 can't see element 2
+        ///     Assert.AreEqual(3, e1.Current);
         /// </example>
         public static IBuffer<TSource> Share<TSource>(this IEnumerable<TSource> source)
         {
@@ -38,8 +41,8 @@ namespace System.Linq
         }
 
         /// <summary>
-        /// Shares the source sequence within a selector function where each enumerator can fetch the next element from the
-        /// source sequence.
+        ///     Shares the source sequence within a selector function where each enumerator can fetch the next element from the
+        ///     source sequence.
         /// </summary>
         /// <typeparam name="TSource">Source sequence element type.</typeparam>
         /// <typeparam name="TResult">Result sequence element type.</typeparam>
@@ -53,10 +56,11 @@ namespace System.Linq
             if (selector == null)
                 throw new ArgumentNullException(nameof(selector));
 
-            return Create(() => selector(source.Share()).GetEnumerator());
+            return Create(() => selector(source.Share())
+                              .GetEnumerator());
         }
 
-        private sealed class SharedBuffer<T> : IBuffer<T>
+        private class SharedBuffer<T> : IBuffer<T>
         {
             private bool _disposed;
             private IEnumerator<T> _source;
@@ -71,7 +75,7 @@ namespace System.Linq
                 if (_disposed)
                     throw new ObjectDisposedException("");
 
-                return GetEnumeratorCore();
+                return GetEnumerator_();
             }
 
             IEnumerator IEnumerable.GetEnumerator()
@@ -96,34 +100,66 @@ namespace System.Linq
                 }
             }
 
-            private IEnumerator<T> GetEnumeratorCore()
+            private IEnumerator<T> GetEnumerator_()
             {
-                while (true)
+                return new ShareEnumerator(this);
+            }
+
+            sealed class ShareEnumerator : IEnumerator<T>
+            {
+                readonly SharedBuffer<T> _parent;
+
+                T _current;
+
+                bool _disposed;
+
+                public ShareEnumerator(SharedBuffer<T> parent)
                 {
-                    if (_disposed)
-                        throw new ObjectDisposedException("");
+                    _parent = parent;
+                }
+
+                public T Current => _current;
+
+                object IEnumerator.Current => _current;
 
-                    var hasValue = default(bool);
-                    var current = default(T);
+                public void Dispose()
+                {
+                    _disposed = true;
+                }
 
-                    lock (_source)
+                public bool MoveNext()
+                {
+                    if (_disposed)
+                    {
+                        return false;
+                    }
+                    if (_parent._disposed)
                     {
-                        hasValue = _source.MoveNext();
+                        throw new ObjectDisposedException("");
+                    }
 
+                    var hasValue = false;
+                    var src = _parent._source;
+                    lock (src)
+                    {
+                        hasValue = src.MoveNext();
                         if (hasValue)
                         {
-                            current = _source.Current;
+                            _current = src.Current;
                         }
                     }
-
                     if (hasValue)
                     {
-                        yield return current;
-                    }
-                    else
-                    {
-                        break;
+                        return true;
                     }
+                    _disposed = true;
+                    _current = default(T);
+                    return false;
+                }
+
+                public void Reset()
+                {
+                    throw new NotSupportedException();
                 }
             }
         }