Kaynağa Gözat

Convert DefaultIfEmpty to use iterator

Oren Novotny 9 yıl önce
ebeveyn
işleme
095995b4ef

+ 66 - 40
Ix.NET/Source/System.Interactive.Async/DefaultIfEmpty.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -17,45 +18,7 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            return CreateEnumerable(
-                () =>
-                {
-                    var done = false;
-                    var hasElements = false;
-                    var e = source.GetEnumerator();
-                    var current = default(TSource);
-
-                    var cts = new CancellationTokenDisposable();
-                    var d = Disposable.Create(cts, e);
-
-                    var f = default(Func<CancellationToken, Task<bool>>);
-                    f = async ct =>
-                        {
-                            if (done)
-                                return false;
-                            if (await e.MoveNext(ct)
-                                       .ConfigureAwait(false))
-                            {
-                                hasElements = true;
-                                current = e.Current;
-                                return true;
-                            }
-                            done = true;
-                            if (!hasElements)
-                            {
-                                current = defaultValue;
-                                return true;
-                            }
-                            return false;
-                        };
-
-                    return CreateEnumerator(
-                        f,
-                        () => current,
-                        d.Dispose,
-                        e
-                    );
-                });
+            return new DefaultIfEmptyAsyncIterator<TSource>(source, defaultValue);
         }
 
         public static IAsyncEnumerable<TSource> DefaultIfEmpty<TSource>(this IAsyncEnumerable<TSource> source)
@@ -63,7 +26,70 @@ namespace System.Linq
             if (source == null)
                 throw new ArgumentNullException(nameof(source));
 
-            return source.DefaultIfEmpty(default(TSource));
+            return DefaultIfEmpty(source, default(TSource));
+        }
+
+        private sealed class DefaultIfEmptyAsyncIterator<TSource> : AsyncIterator<TSource>
+        {
+            private readonly IAsyncEnumerable<TSource> source;
+            private readonly TSource defaultValue;
+            private IAsyncEnumerator<TSource> enumerator;
+
+            public DefaultIfEmptyAsyncIterator(IAsyncEnumerable<TSource> source, TSource defaultValue)
+            {
+                this.source = source;
+                this.defaultValue = defaultValue;
+                Debug.Assert(source != null);
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new DefaultIfEmptyAsyncIterator<TSource>(source, defaultValue);
+            }
+
+            public override void Dispose()
+            {
+                if (enumerator != null)
+                {
+                    enumerator.Dispose();
+                    enumerator = null;
+                }
+
+                base.Dispose();
+            }
+
+            protected override async Task<bool> MoveNextCore(CancellationToken cancellationToken)
+            {
+                switch (state)
+                {
+                    case State.Allocated:
+                        enumerator = source.GetEnumerator();
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            state = State.Iterating;
+                        }
+                        else
+                        {
+                            current = defaultValue;
+                            state = State.Disposed; 
+                        }
+                        return true;
+
+                    case State.Iterating:
+                        if (await enumerator.MoveNext(cancellationToken)
+                                            .ConfigureAwait(false))
+                        {
+                            current = enumerator.Current;
+                            return true;
+                        }
+                        break;
+                }
+
+                Dispose();
+                return false;
+            }
         }
     }
 }