Browse Source

Support subscriptions created by matching in a generic context as well as from basic boolean unary expressions

(All this is very weak and the whole Subscription class should be rewritten at some point anyways.)
Romain Verdier 10 years ago
parent
commit
ee4234732d

+ 2 - 1
src/Abc.Zebus.Tests/Abc.Zebus.Tests.csproj

@@ -128,7 +128,8 @@
     <Compile Include="Lotus\ReplayMessageHandlerTests.cs" />
     <Compile Include="MessageContextTests.cs" />
     <Compile Include="Messages\FakeCommandWithTimestamp.cs" />
-	<Compile Include="Messages\FakeInfrastructureTransientCommand.cs" />
+    <Compile Include="Messages\FakeInfrastructureTransientCommand.cs" />
+    <Compile Include="Messages\FakeRoutableCommandWithBoolean.cs" />
     <Compile Include="Messages\FakeRoutableCommandWithEnum.cs">
       <SubType>Code</SubType>
     </Compile>

+ 21 - 0
src/Abc.Zebus.Tests/Messages/FakeRoutableCommandWithBoolean.cs

@@ -0,0 +1,21 @@
+using Abc.Zebus.Routing;
+using ProtoBuf;
+
+namespace Abc.Zebus.Tests.Messages
+{
+    [ProtoContract, Routable]
+    public class FakeRoutableCommandWithBoolean : ICommand
+    {
+        [ProtoMember(1, IsRequired = true), RoutingPosition(1)]
+        public bool IsAMatch;
+
+        FakeRoutableCommandWithBoolean()
+        {
+        }
+
+        public FakeRoutableCommandWithBoolean(bool isAMatch)
+        {
+            IsAMatch = isAMatch;
+        }
+    }
+}

+ 30 - 0
src/Abc.Zebus.Tests/SubscriptionTests.cs

@@ -138,6 +138,22 @@ namespace Abc.Zebus.Tests
             subscription.BindingKey.ShouldEqual(new BindingKey("12", "name", otherId.ToString()));
         }
 
+        [Test]
+        public void should_create_subscription_from_predicate_with_unary_expressions()
+        {
+            var subscription = Subscription.Matching<FakeRoutableCommandWithBoolean>(x => !x.IsAMatch);
+            subscription.MessageTypeId.ShouldEqual(new MessageTypeId(typeof(FakeRoutableCommandWithBoolean)));
+            subscription.BindingKey.ShouldEqual(new BindingKey("False"));
+
+            subscription = Subscription.Matching<FakeRoutableCommandWithBoolean>(x => x.IsAMatch);
+            subscription.MessageTypeId.ShouldEqual(new MessageTypeId(typeof(FakeRoutableCommandWithBoolean)));
+            subscription.BindingKey.ShouldEqual(new BindingKey("True"));
+            
+            subscription = Subscription.Matching<FakeRoutableCommandWithBoolean>(x => !(!x.IsAMatch));
+            subscription.MessageTypeId.ShouldEqual(new MessageTypeId(typeof(FakeRoutableCommandWithBoolean)));
+            subscription.BindingKey.ShouldEqual(new BindingKey("True"));
+        }
+
         [Test]
         public void should_create_subscription_from_predicate_with_enum()
         {
@@ -169,6 +185,20 @@ namespace Abc.Zebus.Tests
             subscription.MessageTypeId.ShouldEqual(new MessageTypeId(typeof(FakeRoutableCommand)));
             subscription.BindingKey.ShouldEqual(new BindingKey(GetFieldValue().ToString(), "*", "*"));
         }
+        
+        [Test]
+        public void should_create_subscription_from_simple_predicate_in_generic_context()
+        {
+            var subscription = CreateSubscription<FakeRoutableCommand>();
+            subscription.MessageTypeId.ShouldEqual(new MessageTypeId(typeof(FakeRoutableCommand)));
+            subscription.BindingKey.ShouldEqual(new BindingKey(GetFieldValue().ToString(), "*", "*"));
+        }
+
+        private Subscription CreateSubscription<TMessage>()
+            where TMessage : FakeRoutableCommand
+        {
+            return Subscription.Matching<TMessage>(x => x.Id == GetFieldValue());
+        }
 
         [Test]
         public void should_be_equatable()

+ 89 - 12
src/Abc.Zebus/Subscription.cs

@@ -120,11 +120,12 @@ namespace Abc.Zebus
                 throw new ArgumentException();
 
             var parameterValues = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
-            
+
             using (CultureScope.Invariant())
             {
                 var newExpression = (NewExpression)factory.Body;
                 var parameters = newExpression.Constructor.GetParameters();
+
                 for (var argumentIndex = 0; argumentIndex < newExpression.Arguments.Count; ++argumentIndex)
                 {
                     var argumentExpression = newExpression.Arguments[argumentIndex];
@@ -159,39 +160,99 @@ namespace Abc.Zebus
         public static Subscription Matching<TMessage>(Expression<Func<TMessage, bool>> predicate) where TMessage : IMessage
         {
             var fieldValues = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
-            
+
             using (CultureScope.Invariant())
             {
                 var current = predicate.Body;
                 while (current.NodeType == ExpressionType.And || current.NodeType == ExpressionType.AndAlso)
                 {
                     var binaryExpression = (BinaryExpression)current;
-                    AddFieldValue<TMessage>(fieldValues, (BinaryExpression)binaryExpression.Right);
+                    AddFieldValue<TMessage>(fieldValues, binaryExpression.Right);
                     current = binaryExpression.Left;
                 }
 
-                AddFieldValue<TMessage>(fieldValues, (BinaryExpression)current);
+                AddFieldValue<TMessage>(fieldValues, current);
             }
 
             return new Subscription(MessageUtil.TypeId<TMessage>(), BindingKey.Create(typeof(TMessage), fieldValues));
         }
 
-        private static void AddFieldValue<TMessage>(Dictionary<string, string> fieldValues, BinaryExpression expression)
+        private static void AddFieldValue<TMessage>(Dictionary<string, string> fieldValues, Expression expression)
+        {
+            var binaryExpression = expression as BinaryExpression;
+            if (binaryExpression != null)
+            {
+                AddFieldValueFromBinaryExpression<TMessage>(fieldValues, binaryExpression);
+                return;
+            }
+
+            var unaryExpression = expression as UnaryExpression;
+            if (unaryExpression != null)
+            {
+                AddFieldValueFromUnaryExpression<TMessage>(fieldValues, unaryExpression);
+                return;
+            }
+
+            var memberExpression = expression as MemberExpression;
+            if (memberExpression != null)
+            {
+                AddFieldValueFromMemberExpression<TMessage>(fieldValues, memberExpression);
+                return;
+            }
+
+            throw CreateArgumentException(expression);
+        }
+
+        private static void AddFieldValueFromUnaryExpression<T>(Dictionary<string, string> fieldValues, UnaryExpression unaryExpression)
+        {
+            if(unaryExpression.Type != typeof(bool))
+                throw CreateArgumentException(unaryExpression);
+
+            if(unaryExpression.NodeType != ExpressionType.Not)
+                throw CreateArgumentException(unaryExpression);
+
+            var currentFieldValue = false;
+
+            while (unaryExpression.Operand is UnaryExpression)
+            {
+                currentFieldValue = !currentFieldValue;
+                unaryExpression = (UnaryExpression)unaryExpression.Operand;
+            }
+
+            var memberExpression = unaryExpression.Operand as MemberExpression;
+            if (memberExpression == null)
+                throw CreateArgumentException(unaryExpression);
+
+            AddFieldValueFromMemberExpression<T>(fieldValues, memberExpression, currentFieldValue);
+        }
+
+        private static void AddFieldValueFromMemberExpression<TMessage>(Dictionary<string, string> fieldValues, MemberExpression memberExpression, bool fieldValue = true)
+        {
+            if (!IsMessageMemberExpression<TMessage>(memberExpression))
+                throw CreateArgumentException(memberExpression);
+
+            if (memberExpression.Type != typeof(bool))
+                throw CreateArgumentException(memberExpression);
+
+            fieldValues.Add(memberExpression.Member.Name, fieldValue.ToString());
+        }
+
+        private static void AddFieldValueFromBinaryExpression<TMessage>(Dictionary<string, string> fieldValues, BinaryExpression binaryExpression)
         {
             MemberExpression memberExpression;
             Expression memberValueExpression;
 
-            if (TryGetMessageMemberExpression<TMessage>(expression.Right, out memberExpression))
+            if (TryGetMessageMemberExpression<TMessage>(binaryExpression.Right, out memberExpression))
             {
-                memberValueExpression = expression.Left;
+                memberValueExpression = binaryExpression.Left;
             }
-            else if (TryGetMessageMemberExpression<TMessage>(expression.Left, out memberExpression))
+            else if (TryGetMessageMemberExpression<TMessage>(binaryExpression.Left, out memberExpression))
             {
-                memberValueExpression = expression.Right;
+                memberValueExpression = binaryExpression.Right;
             }
             else
             {
-                throw new ArgumentException("Invalid message predicate: " + expression);
+                throw CreateArgumentException(binaryExpression);
             }
 
             var memberName = memberExpression.Member.Name;
@@ -222,8 +283,24 @@ namespace Abc.Zebus
 
         private static bool IsMessageMemberExpression<TMessage>(MemberExpression memberExpression)
         {
-            var containingExpression = memberExpression.Expression as ParameterExpression;
-            return containingExpression != null && containingExpression.Type == typeof(TMessage);
+            var parameterExpression = memberExpression.Expression as ParameterExpression;
+            if (parameterExpression != null)
+                return parameterExpression.Type == typeof(TMessage);
+
+            var convertExpression = memberExpression.Expression as UnaryExpression;
+            if (convertExpression == null || convertExpression.NodeType != ExpressionType.Convert)
+                return false;
+
+            var typedParameterExpression = convertExpression.Operand as ParameterExpression;
+            if (typedParameterExpression == null)
+                return false;
+
+            return typedParameterExpression.Type.IsAssignableFrom(typeof(TMessage));
+        }
+
+        private static ArgumentException CreateArgumentException(Expression expression)
+        {
+            return new ArgumentException("Invalid message predicate: " + expression);
         }
 
         [EditorBrowsable(EditorBrowsableState.Never), UsedImplicitly]