Przeglądaj źródła

HANDLE log message is now displayed just before running handler (it was displayed too early for async invocations)

Olivier Coanet 11 lat temu
rodzic
commit
15cf56963d

+ 27 - 0
src/Abc.Zebus.Tests/Core/BusManualTests.cs

@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Diagnostics;
 using System.Threading;
 using System.Threading;
+using System.Threading.Tasks;
 using ABC.ServiceBus.Contracts;
 using ABC.ServiceBus.Contracts;
 using Abc.Zebus.Core;
 using Abc.Zebus.Core;
 using Abc.Zebus.Dispatch;
 using Abc.Zebus.Dispatch;
@@ -119,6 +120,20 @@ namespace Abc.Zebus.Tests.Core
             }
             }
         }
         }
 
 
+        [Test]
+        public void SendSleepCommands()
+        {
+            var tasks = new List<Task>();
+            using (var bus = CreateBusFactory().WithHandlers(typeof(SleepCommandHandler)).CreateAndStartBus())
+            {
+                for (var i = 0; i < 20; ++i)
+                {
+                    tasks.Add(bus.Send(new SleepCommand()));
+                }
+                Task.WaitAll(tasks.ToArray());
+            }
+        }
+
         private static BusFactory CreateBusFactory()
         private static BusFactory CreateBusFactory()
         {
         {
             return new BusFactory().WithConfiguration(_directoryEndPoint, "Dev");
             return new BusFactory().WithConfiguration(_directoryEndPoint, "Dev");
@@ -201,5 +216,17 @@ namespace Abc.Zebus.Tests.Core
                 LastId = message.Id;
                 LastId = message.Id;
             }
             }
         }
         }
+
+        public class SleepCommand : ICommand
+        {
+        }
+
+        public class SleepCommandHandler : IMessageHandler<SleepCommand>
+        {
+            public void Handle(SleepCommand message)
+            {
+                Thread.Sleep(1000);
+            }
+        }
     }
     }
 }
 }

+ 9 - 0
src/Abc.Zebus.Tests/log4net.config

@@ -21,12 +21,17 @@
   <root>
   <root>
     <level value="INFO" />
     <level value="INFO" />
     <appender-ref ref="RollingFileAppender" />
     <appender-ref ref="RollingFileAppender" />
+    <appender-ref ref="ConsoleAppender" />
   </root>
   </root>
 
 
   <logger name="Abc.Zebus.Persistence.PersistentTransport">
   <logger name="Abc.Zebus.Persistence.PersistentTransport">
     <level value="DEBUG" />
     <level value="DEBUG" />
   </logger>
   </logger>
 
 
+  <logger name="Abc.Zebus.Scan.Pipes.PipeInvocation">
+    <level value="DEBUG" />
+  </logger>
+
   <logger name="Abc.Shared.Configuration.Proxy">
   <logger name="Abc.Shared.Configuration.Proxy">
     <level value="WARN" />
     <level value="WARN" />
   </logger>
   </logger>
@@ -54,4 +59,8 @@
   <logger name="Abc.Zebus.Tests.Core.BusPerformanceTests+PerfEvent">
   <logger name="Abc.Zebus.Tests.Core.BusPerformanceTests+PerfEvent">
     <level value="WARN" />
     <level value="WARN" />
   </logger>
   </logger>
+
+  <logger name="Abc.Zebus.Tests.Core.BusManualTests+SleepCommand">
+    <level value="DEBUG" />
+  </logger>
 </log4net>
 </log4net>

+ 5 - 6
src/Abc.Zebus/Core/Bus.cs

@@ -6,7 +6,6 @@ using System.Threading.Tasks;
 using Abc.Zebus.Directory;
 using Abc.Zebus.Directory;
 using Abc.Zebus.Dispatch;
 using Abc.Zebus.Dispatch;
 using Abc.Zebus.Lotus;
 using Abc.Zebus.Lotus;
-using Abc.Zebus.Monitoring;
 using Abc.Zebus.Serialization;
 using Abc.Zebus.Serialization;
 using Abc.Zebus.Transport;
 using Abc.Zebus.Transport;
 using Abc.Zebus.Util;
 using Abc.Zebus.Util;
@@ -346,7 +345,7 @@ namespace Abc.Zebus.Core
             if (dispatch == null)
             if (dispatch == null)
                 return;
                 return;
 
 
-            _messageLogger.LogFormat("RECV remote: {0} from {3} ({2} bytes). [{1}]", dispatch.Message, transportMessage.Id, transportMessage.MessageBytes.Length, transportMessage.Originator.SenderId);
+            _messageLogger.InfoFormat("RECV remote: {0} from {3} ({2} bytes). [{1}]", dispatch.Message, transportMessage.Id, transportMessage.MessageBytes.Length, transportMessage.Originator.SenderId);
             _messageDispatcher.Dispatch(dispatch);
             _messageDispatcher.Dispatch(dispatch);
         }
         }
 
 
@@ -359,7 +358,7 @@ namespace Abc.Zebus.Core
                 if (dispatch.Message is ICommand)
                 if (dispatch.Message is ICommand)
                 {
                 {
                     var messageExecutionCompleted = MessageExecutionCompleted.Create(dispatch.Context, dispatchResult, _serializer);
                     var messageExecutionCompleted = MessageExecutionCompleted.Create(dispatch.Context, dispatchResult, _serializer);
-                    var shouldLogMessageExecutionCompleted = _messageLogger.IsLogEnabled(dispatch.Message);
+                    var shouldLogMessageExecutionCompleted = _messageLogger.IsInfoEnabled(dispatch.Message);
                     SendTransportMessage(null, messageExecutionCompleted, dispatch.Context.GetSender(), shouldLogMessageExecutionCompleted);
                     SendTransportMessage(null, messageExecutionCompleted, dispatch.Context.GetSender(), shouldLogMessageExecutionCompleted);
                 }
                 }
 
 
@@ -416,7 +415,7 @@ namespace Abc.Zebus.Core
 
 
         protected virtual void HandleLocalMessage(IMessage message, TaskCompletionSource<CommandResult> taskCompletionSource)
         protected virtual void HandleLocalMessage(IMessage message, TaskCompletionSource<CommandResult> taskCompletionSource)
         {
         {
-            _messageLogger.LogFormat("RECV local: {0}", message);
+            _messageLogger.InfoFormat("RECV local: {0}", message);
 
 
             var context = MessageContext.CreateOverride(PeerId, EndPoint);
             var context = MessageContext.CreateOverride(PeerId, EndPoint);
             var dispatch = new MessageDispatch(context, message, GetOnLocalMessageDispatchedContinuation(taskCompletionSource));
             var dispatch = new MessageDispatch(context, message, GetOnLocalMessageDispatchedContinuation(taskCompletionSource));
@@ -448,7 +447,7 @@ namespace Abc.Zebus.Core
             if (peers.Count == 0)
             if (peers.Count == 0)
             {
             {
                 if (logEnabled)
                 if (logEnabled)
-                    _messageLogger.LogFormat("SEND: {0} with no target peer", message);
+                    _messageLogger.InfoFormat("SEND: {0} with no target peer", message);
 
 
                 return;
                 return;
             }
             }
@@ -456,7 +455,7 @@ namespace Abc.Zebus.Core
             var transportMessage = ToTransportMessage(message, messageId ?? MessageId.NextId());
             var transportMessage = ToTransportMessage(message, messageId ?? MessageId.NextId());
 
 
             if (logEnabled)
             if (logEnabled)
-                _messageLogger.LogFormat("SEND: {0} to {3} ({2} bytes) [{1}]", message, transportMessage.Id, transportMessage.MessageBytes.Length, peers);
+                _messageLogger.InfoFormat("SEND: {0} to {3} ({2} bytes) [{1}]", message, transportMessage.Id, transportMessage.MessageBytes.Length, peers);
 
 
             SendTransportMessage(transportMessage, peers);
             SendTransportMessage(transportMessage, peers);
         }
         }

+ 113 - 91
src/Abc.Zebus/Core/BusMessageLogger.cs

@@ -1,20 +1,26 @@
-using System;
-using System.Collections.Concurrent;
-using System.Collections.Generic;
-using System.Reflection;
-using Abc.Zebus.Util.Annotations;
-using Abc.Zebus.Util.Extensions;
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Reflection;
+using Abc.Zebus.Util.Annotations;
+using Abc.Zebus.Util.Extensions;
 using log4net;
 using log4net;
 using log4net.Core;
 using log4net.Core;
 
 
-namespace Abc.Zebus.Core
-{
-    public class BusMessageLogger
+namespace Abc.Zebus.Core
+{
+    public class BusMessageLogger
     {
     {
+        private static readonly ConcurrentDictionary<Type, MessageTypeLogInfo> _logInfos = new ConcurrentDictionary<Type, MessageTypeLogInfo>();
+        private static readonly Func<Type, MessageTypeLogInfo> _logInfoFactory;
         private readonly Type _loggerType;
         private readonly Type _loggerType;
-        private static readonly ConcurrentDictionary<Type, MessageTypeLogInfo> _logInfos = new ConcurrentDictionary<Type, MessageTypeLogInfo>();
         private readonly ILog _logger;
         private readonly ILog _logger;
 
 
+        static BusMessageLogger()
+        {
+            _logInfoFactory = CreateLogger;
+        }
+
         private BusMessageLogger(Type loggerType)
         private BusMessageLogger(Type loggerType)
         {
         {
             _loggerType = loggerType;
             _loggerType = loggerType;
@@ -24,94 +30,110 @@ namespace Abc.Zebus.Core
         public static BusMessageLogger Get<T>()
         public static BusMessageLogger Get<T>()
         {
         {
             return Instance<T>.Value;
             return Instance<T>.Value;
-        }
-
-        public bool IsLogEnabled(IMessage message)
-        {
-            var logInfo = _logInfos.GetOrAdd(message.GetType(), CreateLogger);
-            return logInfo.Logger.IsInfoEnabled;
-        }
-
-        [StringFormatMethod("format")]
-        public void LogFormat(string format, IMessage message, MessageId? messageId = null, int messageSize = 0, PeerId peerId = default(PeerId), Level logLevel = null)
-        {
-            var logInfo = _logInfos.GetOrAdd(message.GetType(), CreateLogger);
-            if (!logInfo.Logger.IsInfoEnabled)
-                return;
-
+        }
+
+        public bool IsInfoEnabled(IMessage message)
+        {
+            var logInfo = GetLogInfo(message);
+            return logInfo.Logger.IsInfoEnabled;
+        }
+
+        [StringFormatMethod("format")]
+        public void InfoFormat(string format, IMessage message, MessageId? messageId = null, int messageSize = 0, PeerId peerId = default(PeerId))
+        {
+            var logInfo = GetLogInfo(message);
+            if (!logInfo.Logger.IsInfoEnabled)
+                return;
+
+            var messageText = logInfo.GetMessageText(message);
+            _logger.InfoFormat(format, messageText, messageId, messageSize, peerId);
+        }
+
+        [StringFormatMethod("format")]
+        public void DebugFormat(string format, IMessage message, MessageId? messageId = null, int messageSize = 0, PeerId peerId = default(PeerId))
+        {
+            var logInfo = GetLogInfo(message);
+            if (!logInfo.Logger.IsDebugEnabled)
+                return;
+
+            var messageText = logInfo.GetMessageText(message);
+            _logger.DebugFormat(format, messageText, messageId, messageSize, peerId);
+        }
+
+        [StringFormatMethod("format")]
+        public void InfoFormat(string format, IMessage message, MessageId messageId, int messageSize, IList<Peer> peers, Level logLevel = null)
+        {
+            if (peers.Count == 0)
+            {
+                InfoFormat(format, message, messageId, messageSize);
+                return;
+            }
+            if (peers.Count == 1)
+            {
+                InfoFormat(format, message, messageId, messageSize, peerId: peers[0].Id);
+                return;
+            }
+
+            var logInfo = GetLogInfo(message);
+            if (!logInfo.Logger.IsInfoEnabled)
+                return;
+
             var messageText = logInfo.GetMessageText(message);
             var messageText = logInfo.GetMessageText(message);
-            _logger.Logger.Log(_loggerType, logLevel ?? Level.Info, string.Format(format, messageText, messageId, messageSize, peerId), null);
-        }
-
-        [StringFormatMethod("format")]
-        public void LogFormat(string format, IMessage message, MessageId messageId, int messageSize, IList<Peer> peers, Level logLevel = null)
-        {
-            if (peers.Count == 0)
-            {
-                LogFormat(format, message, messageId, messageSize);
-                return;
-            }
-            if (peers.Count == 1)
-            {
-                LogFormat(format, message, messageId, messageSize, peerId: peers[0].Id);
-                return;
-            }
-
-            var logInfo = _logInfos.GetOrAdd(message.GetType(), CreateLogger);
-            if (!logInfo.Logger.IsInfoEnabled)
-                return;
-
-            var messageText = logInfo.GetMessageText(message);
-            var otherPeersCount = peers.Count - 1;
-            var peerIdText = otherPeersCount > 1
+            var otherPeersCount = peers.Count - 1;
+            var peerIdText = otherPeersCount > 1
                 ? peers[0].Id + " and " + otherPeersCount + " other peers"
                 ? peers[0].Id + " and " + otherPeersCount + " other peers"
                 : peers[0].Id + " and " + otherPeersCount + " other peer";
                 : peers[0].Id + " and " + otherPeersCount + " other peer";
 
 
-            _logger.Logger.Log(_loggerType, logLevel ?? Level.Info, string.Format(format, messageText, messageId, messageSize, peerIdText), null);
-        }
-
-        public static string ToString(IMessage message)
-        {
-            var logInfo = _logInfos.GetOrAdd(message.GetType(), CreateLogger);
-            return logInfo.GetMessageText(message);
-        }
-
-        private static MessageTypeLogInfo CreateLogger(Type messageType)
-        {
-            var logger = LogManager.GetLogger(messageType);
-            var hasToStringOverride = HasToStringOverride(messageType);
-
-            return new MessageTypeLogInfo(logger, hasToStringOverride, messageType.GetPrettyName());
-        }
-
-        private static bool HasToStringOverride(Type messageType)
-        {
-            var methodInfo = messageType.GetMethod("ToString", BindingFlags.Instance | BindingFlags.Public | BindingFlags.DeclaredOnly);
-            return methodInfo != null;
-        }
-
-        private class MessageTypeLogInfo
-        {
-            public readonly ILog Logger;
-            private readonly bool _hasToStringOverride;
-            private readonly string _messageTypeName;
-
-            public MessageTypeLogInfo(ILog logger, bool hasToStringOverride, string messageTypeName)
-            {
-                Logger = logger;
-                _hasToStringOverride = hasToStringOverride;
-                _messageTypeName = messageTypeName;
-            }
-
-            public string GetMessageText(IMessage message)
-            {
-                return _hasToStringOverride ? string.Format("{0} {{{1}}}", _messageTypeName, message) : string.Format("{0}", _messageTypeName);
-            }
+            _logger.Logger.Log(_loggerType, logLevel ?? Level.Info, string.Format(format, messageText, messageId, messageSize, peerIdText), null);
+        }
+
+        public static string ToString(IMessage message)
+        {
+            var logInfo = GetLogInfo(message);
+            return logInfo.GetMessageText(message);
+        }
+
+        private static MessageTypeLogInfo GetLogInfo(IMessage message)
+        {
+            return _logInfos.GetOrAdd(message.GetType(), _logInfoFactory);
+        }
+
+        private static MessageTypeLogInfo CreateLogger(Type messageType)
+        {
+            var logger = LogManager.GetLogger(messageType);
+            var hasToStringOverride = HasToStringOverride(messageType);
+
+            return new MessageTypeLogInfo(logger, hasToStringOverride, messageType.GetPrettyName());
+        }
+
+        private static bool HasToStringOverride(Type messageType)
+        {
+            var methodInfo = messageType.GetMethod("ToString", BindingFlags.Instance | BindingFlags.Public | BindingFlags.DeclaredOnly);
+            return methodInfo != null;
+        }
+
+        private class MessageTypeLogInfo
+        {
+            public readonly ILog Logger;
+            private readonly bool _hasToStringOverride;
+            private readonly string _messageTypeName;
+
+            public MessageTypeLogInfo(ILog logger, bool hasToStringOverride, string messageTypeName)
+            {
+                Logger = logger;
+                _hasToStringOverride = hasToStringOverride;
+                _messageTypeName = messageTypeName;
+            }
+
+            public string GetMessageText(IMessage message)
+            {
+                return _hasToStringOverride ? string.Format("{0} {{{1}}}", _messageTypeName, message) : string.Format("{0}", _messageTypeName);
+            }
         }
         }
 
 
         private static class Instance<T>
         private static class Instance<T>
         {
         {
             public static readonly BusMessageLogger Value = new BusMessageLogger(typeof(T));
             public static readonly BusMessageLogger Value = new BusMessageLogger(typeof(T));
-        }
-    }
+        }
+    }
 }
 }

+ 131 - 130
src/Abc.Zebus/Scan/Pipes/PipeInvocation.cs

@@ -1,137 +1,138 @@
-using System;
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
 using System.Threading.Tasks;
 using System.Threading.Tasks;
 using Abc.Zebus.Core;
 using Abc.Zebus.Core;
-using Abc.Zebus.Dispatch;
+using Abc.Zebus.Dispatch;
 using Abc.Zebus.Util.Extensions;
 using Abc.Zebus.Util.Extensions;
-using log4net.Core;
-
-namespace Abc.Zebus.Scan.Pipes
-{
-    public class PipeInvocation : IMessageHandlerInvocation
-    {
-        private readonly List<Action<object>> _handlerMutations = new List<Action<object>>();
-        private readonly IMessageHandlerInvoker _invoker;
-        private readonly IMessage _message;
-        private readonly MessageContext _messageContext;
+
+namespace Abc.Zebus.Scan.Pipes
+{
+    public class PipeInvocation : IMessageHandlerInvocation
+    {
+        private readonly List<Action<object>> _handlerMutations = new List<Action<object>>();
+        private readonly IMessageHandlerInvoker _invoker;
+        private readonly IMessage _message;
+        private readonly MessageContext _messageContext;
         private readonly IList<IPipe> _pipes;
         private readonly IList<IPipe> _pipes;
-        private readonly BusMessageLogger _messageLogger = BusMessageLogger.Get<PipeInvocation>();
-        
-        private object[] _pipeStates;
-
-        public PipeInvocation(IMessageHandlerInvoker invoker, IMessage message, MessageContext messageContext, IEnumerable<IPipe> pipes)
-        {
-            _invoker = invoker;
-            _message = message;
-            _messageContext = messageContext;
-            _pipes = pipes.AsList();
-        }
-
-        internal IList<IPipe> Pipes
-        {
-            get { return _pipes; }
-        }
-
-        public IMessageHandlerInvoker Invoker
-        {
-            get { return _invoker; }
-        }
-
-        public IMessage Message
-        {
-            get { return _message; }
-        }
-
-        public MessageContext Context
-        {
-            get { return _messageContext; }
-        }
-
-        public void AddHandlerMutation(Action<object> action)
-        {
-            _handlerMutations.Add(action);
-        }
-
+        private readonly BusMessageLogger _messageLogger = BusMessageLogger.Get<PipeInvocation>();
+        
+        private object[] _pipeStates;
+
+        public PipeInvocation(IMessageHandlerInvoker invoker, IMessage message, MessageContext messageContext, IEnumerable<IPipe> pipes)
+        {
+            _invoker = invoker;
+            _message = message;
+            _messageContext = messageContext;
+            _pipes = pipes.AsList();
+        }
+
+        internal IList<IPipe> Pipes
+        {
+            get { return _pipes; }
+        }
+
+        public IMessageHandlerInvoker Invoker
+        {
+            get { return _invoker; }
+        }
+
+        public IMessage Message
+        {
+            get { return _message; }
+        }
+
+        public MessageContext Context
+        {
+            get { return _messageContext; }
+        }
+
+        public void AddHandlerMutation(Action<object> action)
+        {
+            _handlerMutations.Add(action);
+        }
+
         protected internal virtual void Run()
         protected internal virtual void Run()
         {
         {
-            _messageLogger.LogFormat("HANDLE : {0} [{1}]", _message, _messageContext.MessageId, logLevel: Level.Debug);
-            _pipeStates = BeforeInvoke();
-
-            try
-            {
-                _invoker.InvokeMessageHandler(this);
-            }
-            catch (Exception exception)
-            {
-                AfterInvoke(_pipeStates, true, exception);
-                throw;
-            }
-
-            AfterInvoke(_pipeStates, false, null);
-        }
-
-        private object[] BeforeInvoke()
-        {
-            var stateRef = new BeforeInvokeArgs.StateRef();
-            var pipeStates = new object[_pipes.Count];
-            for (var pipeIndex = 0; pipeIndex < _pipes.Count; ++pipeIndex)
-            {
-                var beforeInvokeArgs = new BeforeInvokeArgs(this, stateRef);
-                _pipes[pipeIndex].BeforeInvoke(beforeInvokeArgs);
-                pipeStates[pipeIndex] = beforeInvokeArgs.State;
-            }
-            return pipeStates;
-        }
-
-        private void AfterInvoke(object[] pipeStates, bool isFaulted, Exception exception)
-        {
-            if (pipeStates == null)
-                throw new InvalidOperationException("Missing pipe states, did you call SetupForInvocation in your handler invoker?");
-
-            for (var pipeIndex = _pipes.Count - 1; pipeIndex >= 0; --pipeIndex)
-            {
-                var afterInvokeArgs = new AfterInvokeArgs(this, pipeStates[pipeIndex], isFaulted, exception);
-                _pipes[pipeIndex].AfterInvoke(afterInvokeArgs);
-            }
-        }
-
-        protected internal virtual Task RunAsync()
+            _pipeStates = BeforeInvoke();
+
+            try
+            {
+                _invoker.InvokeMessageHandler(this);
+            }
+            catch (Exception exception)
+            {
+                AfterInvoke(_pipeStates, true, exception);
+                throw;
+            }
+
+            AfterInvoke(_pipeStates, false, null);
+        }
+
+        private object[] BeforeInvoke()
         {
         {
-            _messageLogger.LogFormat("HANDLE : {0} [{1}]", _message, _messageContext.MessageId, logLevel: Level.Debug);
-            var runTask = _invoker.InvokeMessageHandlerAsync(this);
-            runTask.ContinueWith(task => AfterInvoke(_pipeStates, task.IsFaulted, task.Exception), TaskContinuationOptions.ExecuteSynchronously);
-
-            return runTask;
-        }
-
-        IDisposable IMessageHandlerInvocation.SetupForInvocation()
-        {
-            if (_pipeStates == null)
-                _pipeStates = BeforeInvoke();
-
-            return MessageContext.SetCurrent(_messageContext);
-        }
-
-        IDisposable IMessageHandlerInvocation.SetupForInvocation(object messageHandler)
-        {
-            if (_pipeStates == null)
-                _pipeStates = BeforeInvoke();
-
-            ApplyMutations(messageHandler);
-
-            return MessageContext.SetCurrent(_messageContext);
-        }
-
-        private void ApplyMutations(object messageHandler)
-        {
-            var messageContextAwareHandler = messageHandler as IMessageContextAware;
-            if (messageContextAwareHandler != null)
-                messageContextAwareHandler.Context = Context;
-
-            foreach (var messageHandlerMutation in _handlerMutations)
-            {
-                messageHandlerMutation(messageHandler);
-            }
-        }
-    }
-}
+            var stateRef = new BeforeInvokeArgs.StateRef();
+            var pipeStates = new object[_pipes.Count];
+            for (var pipeIndex = 0; pipeIndex < _pipes.Count; ++pipeIndex)
+            {
+                var beforeInvokeArgs = new BeforeInvokeArgs(this, stateRef);
+                _pipes[pipeIndex].BeforeInvoke(beforeInvokeArgs);
+                pipeStates[pipeIndex] = beforeInvokeArgs.State;
+            }
+            return pipeStates;
+        }
+
+        private void AfterInvoke(object[] pipeStates, bool isFaulted, Exception exception)
+        {
+            if (pipeStates == null)
+                throw new InvalidOperationException("Missing pipe states, did you call SetupForInvocation in your handler invoker?");
+
+            for (var pipeIndex = _pipes.Count - 1; pipeIndex >= 0; --pipeIndex)
+            {
+                var afterInvokeArgs = new AfterInvokeArgs(this, pipeStates[pipeIndex], isFaulted, exception);
+                _pipes[pipeIndex].AfterInvoke(afterInvokeArgs);
+            }
+        }
+
+        protected internal virtual Task RunAsync()
+        {
+            var runTask = _invoker.InvokeMessageHandlerAsync(this);
+            runTask.ContinueWith(task => AfterInvoke(_pipeStates, task.IsFaulted, task.Exception), TaskContinuationOptions.ExecuteSynchronously);
+
+            return runTask;
+        }
+
+        IDisposable IMessageHandlerInvocation.SetupForInvocation()
+        {
+            if (_pipeStates == null)
+                _pipeStates = BeforeInvoke();
+
+            _messageLogger.DebugFormat("HANDLE : {0} [{1}]", _message, _messageContext.MessageId);
+
+            return MessageContext.SetCurrent(_messageContext);
+        }
+
+        IDisposable IMessageHandlerInvocation.SetupForInvocation(object messageHandler)
+        {
+            if (_pipeStates == null)
+                _pipeStates = BeforeInvoke();
+
+            _messageLogger.DebugFormat("HANDLE : {0} [{1}]", _message, _messageContext.MessageId);
+
+            ApplyMutations(messageHandler);
+
+            return MessageContext.SetCurrent(_messageContext);
+        }
+
+        private void ApplyMutations(object messageHandler)
+        {
+            var messageContextAwareHandler = messageHandler as IMessageContextAware;
+            if (messageContextAwareHandler != null)
+                messageContextAwareHandler.Context = Context;
+
+            foreach (var messageHandlerMutation in _handlerMutations)
+            {
+                messageHandlerMutation(messageHandler);
+            }
+        }
+    }
+}