Browse Source

Merge pull request #292 from Reactive-Extensions/join-fix

Join fix
Oren Novotny 9 years ago
parent
commit
cdcc1c9c7f

+ 7 - 4
Ix.NET/Source/System.Interactive.Async/Join.cs

@@ -97,7 +97,7 @@ namespace System.Linq
             TOuter item;
             private int mode;
 
-            const int State_Begin = 1;
+            const int State_If = 1;
             const int State_DoLoop = 2;
             const int State_For = 3;
             const int State_While = 4;
@@ -108,18 +108,19 @@ namespace System.Linq
                 {
                     case AsyncIteratorState.Allocated:
                         outerEnumerator = outer.GetEnumerator();
-                        mode = State_Begin;
+                        mode = State_If;
                         state = AsyncIteratorState.Iterating;
                         goto case AsyncIteratorState.Iterating;
 
                     case AsyncIteratorState.Iterating:
                         switch (mode)
                         {
-                            case State_Begin:
+                            case State_If:
                                 if (await outerEnumerator.MoveNext(cancellationToken)
                                                          .ConfigureAwait(false))
                                 {
                                     lookup = await Internal.Lookup<TKey, TInner>.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false);
+
                                     if (lookup.Count != 0)
                                     {
                                         mode = State_DoLoop;
@@ -140,7 +141,9 @@ namespace System.Linq
                                     goto case State_For;
                                 }
 
-                                break;
+                                // advance to while
+                                mode = State_While;
+                                goto case State_While;
 
                             case State_For:
                                 current = resultSelector(item, elements[index]);

+ 121 - 0
Ix.NET/Source/Tests/AsyncTests.Multiple.cs

@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Text;
 using Xunit;
@@ -891,6 +892,73 @@ namespace Tests
         }
 
 
+        [Fact]
+        public void Join11()
+        {
+            var customers = new List<Customer>
+            {
+                new Customer {CustomerId = "ALFKI"},
+                new Customer {CustomerId = "ANANT"},
+                new Customer {CustomerId = "FISSA"}
+            };
+            var orders = new List<Order>
+            {
+                new Order { OrderId = 1, CustomerId = "ALFKI"},
+                new Order { OrderId = 2, CustomerId = "ALFKI"},
+                new Order { OrderId = 3, CustomerId = "ALFKI"},
+                new Order { OrderId = 4, CustomerId = "FISSA"},
+                new Order { OrderId = 5, CustomerId = "FISSA"},
+                new Order { OrderId = 6, CustomerId = "FISSA"},
+            };
+
+            var asyncResult = customers.ToAsyncEnumerable()
+                                       .Join(orders.ToAsyncEnumerable(), c => c.CustomerId, o => o.CustomerId,
+                                            (c, o) => new CustomerOrder { CustomerId = c.CustomerId, OrderId = o.OrderId });
+
+            var e = asyncResult.GetEnumerator();
+            HasNext(e, new CustomerOrder { CustomerId = "ALFKI", OrderId = 1 });
+            HasNext(e, new CustomerOrder { CustomerId = "ALFKI", OrderId = 2 });
+            HasNext(e, new CustomerOrder { CustomerId = "ALFKI", OrderId = 3 });
+            HasNext(e, new CustomerOrder { CustomerId = "FISSA", OrderId = 4 });
+            HasNext(e, new CustomerOrder { CustomerId = "FISSA", OrderId = 5 });
+            HasNext(e, new CustomerOrder { CustomerId = "FISSA", OrderId = 6 });
+            NoNext(e);
+        }
+
+        [Fact]
+        public void Join12()
+        {
+            var customers = new List<Customer>
+            {
+                new Customer {CustomerId = "ANANT"},
+                new Customer {CustomerId = "ALFKI"},
+                new Customer {CustomerId = "FISSA"}
+            };
+            var orders = new List<Order>
+            {
+                new Order { OrderId = 1, CustomerId = "ALFKI"},
+                new Order { OrderId = 2, CustomerId = "ALFKI"},
+                new Order { OrderId = 3, CustomerId = "ALFKI"},
+                new Order { OrderId = 4, CustomerId = "FISSA"},
+                new Order { OrderId = 5, CustomerId = "FISSA"},
+                new Order { OrderId = 6, CustomerId = "FISSA"},
+            };
+
+            var asyncResult = customers.ToAsyncEnumerable()
+                                       .Join(orders.ToAsyncEnumerable(), c => c.CustomerId, o => o.CustomerId,
+                                            (c, o) => new CustomerOrder { CustomerId = c.CustomerId, OrderId = o.OrderId });
+
+            var e = asyncResult.GetEnumerator();
+            HasNext(e, new CustomerOrder { CustomerId = "ALFKI", OrderId = 1 });
+            HasNext(e, new CustomerOrder { CustomerId = "ALFKI", OrderId = 2 });
+            HasNext(e, new CustomerOrder { CustomerId = "ALFKI", OrderId = 3 });
+            HasNext(e, new CustomerOrder { CustomerId = "FISSA", OrderId = 4 });
+            HasNext(e, new CustomerOrder { CustomerId = "FISSA", OrderId = 5 });
+            HasNext(e, new CustomerOrder { CustomerId = "FISSA", OrderId = 6 });
+            NoNext(e);
+        }
+
+
         [Fact]
         public void SelectManyMultiple_Null()
         {
@@ -915,5 +983,58 @@ namespace Tests
             HasNext(e, 4);
             NoNext(e);
         }
+
+
+        public class Customer
+        {
+            public string CustomerId { get; set; }
+        }
+
+        public class Order
+        {
+            public int OrderId { get; set; }
+            public string CustomerId { get; set; }
+        }
+
+        [DebuggerDisplay("CustomerId = {CustomerId}, OrderId = {OrderId}")]
+        public class CustomerOrder : IEquatable<CustomerOrder>
+        {
+            public bool Equals(CustomerOrder other)
+            {
+                if (ReferenceEquals(null, other)) return false;
+                if (ReferenceEquals(this, other)) return true;
+                return OrderId == other.OrderId && string.Equals(CustomerId, other.CustomerId);
+            }
+
+            public override bool Equals(object obj)
+            {
+                if (ReferenceEquals(null, obj)) return false;
+                if (ReferenceEquals(this, obj)) return true;
+                if (obj.GetType() != this.GetType()) return false;
+                return Equals((CustomerOrder)obj);
+            }
+
+            public override int GetHashCode()
+            {
+                unchecked
+                {
+                    return (OrderId * 397) ^ (CustomerId != null ? CustomerId.GetHashCode() : 0);
+                }
+            }
+
+            public static bool operator ==(CustomerOrder left, CustomerOrder right)
+            {
+                return Equals(left, right);
+            }
+
+            public static bool operator !=(CustomerOrder left, CustomerOrder right)
+            {
+                return !Equals(left, right);
+            }
+
+            public int OrderId { get; set; }
+            public string CustomerId { get; set; }
+        }
+
     }
 }