Sfoglia il codice sorgente

Async variants of Distinct and DistinctUntilChanged.

Bart De Smet 8 anni fa
parent
commit
9a752692dd

+ 282 - 16
Ix.NET/Source/System.Interactive.Async/Distinct.cs

@@ -51,6 +51,28 @@ namespace System.Linq
             return new DistinctAsyncIterator<TSource, TKey>(source, keySelector, comparer);
         }
 
+        public static IAsyncEnumerable<TSource> Distinct<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+
+            return source.Distinct(keySelector, EqualityComparer<TKey>.Default);
+        }
+
+        public static IAsyncEnumerable<TSource> Distinct<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IEqualityComparer<TKey> comparer)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return new DistinctAsyncIteratorWithTask<TSource, TKey>(source, keySelector, comparer);
+        }
+
         public static IAsyncEnumerable<TSource> DistinctUntilChanged<TSource>(this IAsyncEnumerable<TSource> source)
         {
             if (source == null)
@@ -91,11 +113,149 @@ namespace System.Linq
             return source.DistinctUntilChanged_(keySelector, comparer);
         }
 
+        public static IAsyncEnumerable<TSource> DistinctUntilChanged<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+
+            return source.DistinctUntilChanged_(keySelector, EqualityComparer<TKey>.Default);
+        }
+
+        public static IAsyncEnumerable<TSource> DistinctUntilChanged<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IEqualityComparer<TKey> comparer)
+        {
+            if (source == null)
+                throw new ArgumentNullException(nameof(source));
+            if (keySelector == null)
+                throw new ArgumentNullException(nameof(keySelector));
+            if (comparer == null)
+                throw new ArgumentNullException(nameof(comparer));
+
+            return source.DistinctUntilChanged_(keySelector, comparer);
+        }
+
         private static IAsyncEnumerable<TSource> DistinctUntilChanged_<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
         {
             return new DistinctUntilChangedAsyncIterator<TSource, TKey>(source, keySelector, comparer);
         }
 
+        private static IAsyncEnumerable<TSource> DistinctUntilChanged_<TSource, TKey>(this IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IEqualityComparer<TKey> comparer)
+        {
+            return new DistinctUntilChangedAsyncIteratorWithTask<TSource, TKey>(source, keySelector, comparer);
+        }
+
+        private sealed class DistinctAsyncIterator<TSource> : AsyncIterator<TSource>, IIListProvider<TSource>
+        {
+            private readonly IEqualityComparer<TSource> comparer;
+            private readonly IAsyncEnumerable<TSource> source;
+
+            private IAsyncEnumerator<TSource> enumerator;
+            private Set<TSource> set;
+
+            public DistinctAsyncIterator(IAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
+            {
+                Debug.Assert(source != null);
+
+                this.source = source;
+                this.comparer = comparer;
+            }
+
+            public async Task<TSource[]> ToArrayAsync(CancellationToken cancellationToken)
+            {
+                var s = await FillSetAsync(cancellationToken).ConfigureAwait(false);
+                return s.ToArray();
+            }
+
+            public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
+            {
+                var s = await FillSetAsync(cancellationToken)
+                            .ConfigureAwait(false);
+                return s.ToList();
+            }
+
+            public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
+            {
+                return onlyIfCheap ? -1 : (await FillSetAsync(cancellationToken).ConfigureAwait(false)).Count;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new DistinctAsyncIterator<TSource>(source, comparer);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                    set = null;
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        if (!await enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            await DisposeAsync().ConfigureAwait(false);
+                            return false;
+                        }
+
+                        var element = enumerator.Current;
+                        set = new Set<TSource>(comparer);
+                        set.Add(element);
+                        current = element;
+
+                        state = AsyncIteratorState.Iterating;
+                        return true;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            element = enumerator.Current;
+                            if (set.Add(element))
+                            {
+                                current = element;
+                                return true;
+                            }
+                        }
+
+                        break;
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+
+            private async Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
+            {
+                var s = new Set<TSource>(comparer);
+
+                var enu = source.GetAsyncEnumerator();
+
+                try
+                {
+                    while (await enu.MoveNextAsync(cancellationToken).ConfigureAwait(false))
+                    {
+                        s.Add(enu.Current);
+                    }
+                }
+                finally
+                {
+                    await enu.DisposeAsync().ConfigureAwait(false);
+                }
+
+                return s;
+            }
+        }
+
         private sealed class DistinctAsyncIterator<TSource, TKey> : AsyncIterator<TSource>, IIListProvider<TSource>
         {
             private readonly IEqualityComparer<TKey> comparer;
@@ -242,19 +402,23 @@ namespace System.Linq
             }
         }
 
-        private sealed class DistinctAsyncIterator<TSource> : AsyncIterator<TSource>, IIListProvider<TSource>
+        private sealed class DistinctAsyncIteratorWithTask<TSource, TKey> : AsyncIterator<TSource>, IIListProvider<TSource>
         {
-            private readonly IEqualityComparer<TSource> comparer;
+            private readonly IEqualityComparer<TKey> comparer;
+            private readonly Func<TSource, Task<TKey>> keySelector;
             private readonly IAsyncEnumerable<TSource> source;
 
             private IAsyncEnumerator<TSource> enumerator;
-            private Set<TSource> set;
+            private Set<TKey> set;
 
-            public DistinctAsyncIterator(IAsyncEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
+            public DistinctAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IEqualityComparer<TKey> comparer)
             {
                 Debug.Assert(source != null);
+                Debug.Assert(keySelector != null);
+                Debug.Assert(comparer != null);
 
                 this.source = source;
+                this.keySelector = keySelector;
                 this.comparer = comparer;
             }
 
@@ -266,19 +430,44 @@ namespace System.Linq
 
             public async Task<List<TSource>> ToListAsync(CancellationToken cancellationToken)
             {
-                var s = await FillSetAsync(cancellationToken)
-                            .ConfigureAwait(false);
-                return s.ToList();
+                var s = await FillSetAsync(cancellationToken).ConfigureAwait(false);
+                return s;
             }
 
             public async Task<int> GetCountAsync(bool onlyIfCheap, CancellationToken cancellationToken)
             {
-                return onlyIfCheap ? -1 : (await FillSetAsync(cancellationToken).ConfigureAwait(false)).Count;
+                if (onlyIfCheap)
+                {
+                    return -1;
+                }
+
+                var count = 0;
+                var s = new Set<TKey>(comparer);
+
+                var enu = source.GetAsyncEnumerator();
+
+                try
+                {
+                    while (await enu.MoveNextAsync().ConfigureAwait(false))
+                    {
+                        var item = enu.Current;
+                        if (s.Add(await keySelector(item).ConfigureAwait(false)))
+                        {
+                            count++;
+                        }
+                    }
+                }
+                finally
+                {
+                    await enu.DisposeAsync().ConfigureAwait(false);
+                }
+
+                return count;
             }
 
             public override AsyncIterator<TSource> Clone()
             {
-                return new DistinctAsyncIterator<TSource>(source, comparer);
+                return new DistinctAsyncIteratorWithTask<TSource, TKey>(source, keySelector, comparer);
             }
 
             public override async Task DisposeAsync()
@@ -299,6 +488,7 @@ namespace System.Linq
                 {
                     case AsyncIteratorState.Allocated:
                         enumerator = source.GetAsyncEnumerator();
+
                         if (!await enumerator.MoveNextAsync().ConfigureAwait(false))
                         {
                             await DisposeAsync().ConfigureAwait(false);
@@ -306,8 +496,8 @@ namespace System.Linq
                         }
 
                         var element = enumerator.Current;
-                        set = new Set<TSource>(comparer);
-                        set.Add(element);
+                        set = new Set<TKey>(comparer);
+                        set.Add(await keySelector(element).ConfigureAwait(false));
                         current = element;
 
                         state = AsyncIteratorState.Iterating;
@@ -317,7 +507,7 @@ namespace System.Linq
                         while (await enumerator.MoveNextAsync().ConfigureAwait(false))
                         {
                             element = enumerator.Current;
-                            if (set.Add(element))
+                            if (set.Add(await keySelector(element).ConfigureAwait(false)))
                             {
                                 current = element;
                                 return true;
@@ -331,9 +521,10 @@ namespace System.Linq
                 return false;
             }
 
-            private async Task<Set<TSource>> FillSetAsync(CancellationToken cancellationToken)
+            private async Task<List<TSource>> FillSetAsync(CancellationToken cancellationToken)
             {
-                var s = new Set<TSource>(comparer);
+                var s = new Set<TKey>(comparer);
+                var r = new List<TSource>();
 
                 var enu = source.GetAsyncEnumerator();
 
@@ -341,7 +532,11 @@ namespace System.Linq
                 {
                     while (await enu.MoveNextAsync(cancellationToken).ConfigureAwait(false))
                     {
-                        s.Add(enu.Current);
+                        var item = enu.Current;
+                        if (s.Add(await keySelector(item).ConfigureAwait(false)))
+                        {
+                            r.Add(item);
+                        }
                     }
                 }
                 finally
@@ -349,7 +544,7 @@ namespace System.Linq
                     await enu.DisposeAsync().ConfigureAwait(false);
                 }
 
-                return s;
+                return r;
             }
         }
 
@@ -495,5 +690,76 @@ namespace System.Linq
                 return false;
             }
         }
+
+        private sealed class DistinctUntilChangedAsyncIteratorWithTask<TSource, TKey> : AsyncIterator<TSource>
+        {
+            private readonly IEqualityComparer<TKey> comparer;
+            private readonly Func<TSource, Task<TKey>> keySelector;
+            private readonly IAsyncEnumerable<TSource> source;
+            private TKey currentKeyValue;
+
+            private IAsyncEnumerator<TSource> enumerator;
+            private bool hasCurrentKey;
+
+            public DistinctUntilChangedAsyncIteratorWithTask(IAsyncEnumerable<TSource> source, Func<TSource, Task<TKey>> keySelector, IEqualityComparer<TKey> comparer)
+            {
+                this.source = source;
+                this.keySelector = keySelector;
+                this.comparer = comparer;
+            }
+
+            public override AsyncIterator<TSource> Clone()
+            {
+                return new DistinctUntilChangedAsyncIteratorWithTask<TSource, TKey>(source, keySelector, comparer);
+            }
+
+            public override async Task DisposeAsync()
+            {
+                if (enumerator != null)
+                {
+                    await enumerator.DisposeAsync().ConfigureAwait(false);
+                    enumerator = null;
+                    currentKeyValue = default(TKey);
+                }
+
+                await base.DisposeAsync().ConfigureAwait(false);
+            }
+
+            protected override async Task<bool> MoveNextCore()
+            {
+                switch (state)
+                {
+                    case AsyncIteratorState.Allocated:
+                        enumerator = source.GetAsyncEnumerator();
+                        state = AsyncIteratorState.Iterating;
+                        goto case AsyncIteratorState.Iterating;
+
+                    case AsyncIteratorState.Iterating:
+                        while (await enumerator.MoveNextAsync().ConfigureAwait(false))
+                        {
+                            var item = enumerator.Current;
+                            var key = await keySelector(item).ConfigureAwait(false);
+                            var comparerEquals = false;
+
+                            if (hasCurrentKey)
+                            {
+                                comparerEquals = comparer.Equals(currentKeyValue, key);
+                            }
+                            if (!hasCurrentKey || !comparerEquals)
+                            {
+                                hasCurrentKey = true;
+                                currentKeyValue = key;
+                                current = item;
+                                return true;
+                            }
+                        }
+
+                        break; // case
+                }
+
+                await DisposeAsync().ConfigureAwait(false);
+                return false;
+            }
+        }
     }
 }

+ 2 - 2
Ix.NET/Source/Tests/AsyncTests.Single.cs

@@ -3260,10 +3260,10 @@ namespace Tests
         public void DistinctKey_Null()
         {
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(default(IAsyncEnumerable<int>), x => x));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(AsyncEnumerable.Return(42), null));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(AsyncEnumerable.Return(42), default(Func<int, int>)));
 
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(default(IAsyncEnumerable<int>), x => x, EqualityComparer<int>.Default));
-            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(AsyncEnumerable.Return(42), null, EqualityComparer<int>.Default));
+            AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(AsyncEnumerable.Return(42), default(Func<int, int>), EqualityComparer<int>.Default));
             AssertThrows<ArgumentNullException>(() => AsyncEnumerable.Distinct(AsyncEnumerable.Return(42), x => x, null));
         }