
Publishing strategies in MediatR

EDIT: There are breaking changes in MediatR 12.0.1, and the following implementation won’t work. For v12, please check the new article here.

The Mediator pattern is a behavioral design pattern that promotes loose coupling by routing communication between objects through a central object, the mediator. It is particularly useful in complex systems with many components that need to interact with each other. In the Mediator pattern, components don’t interact directly; instead, they send messages to the mediator, which coordinates the communication between them. This reduces the dependencies between components, so they can be added, removed, or modified without affecting other parts of the system, which makes the system easier to maintain and evolve over time. The mediator’s primary responsibility is to facilitate these interactions and ensure that the components collaborate correctly.

MediatR is one of the most popular and widely used implementations of this pattern in .NET. It supports various scenarios, including request/response, commands, queries, and notifications. In this article, we’ll focus on notifications and how to use different publishing strategies. By default, the library invokes and awaits each notification handler sequentially. If a handler throws an exception, execution stops, and the remaining handlers won’t run. We’d like to extend this behavior and choose a different strategy on demand. First, let’s define our requirements.

  • Implement multiple publishing strategies.
  • We should be able to choose a strategy while publishing a notification.
  • Ideally, the feature should be an extension to IMediator. Consumers should not have to deal with new types.
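To make the default behavior concrete, here is a minimal, self-contained sketch that does not use MediatR at all. The handlers are plain delegates standing in for notification handlers, and PublishSequentially is a hypothetical helper that mimics the behavior described earlier: each handler is awaited in turn, and the first exception ends the loop.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

// Records which (hypothetical) handlers actually ran.
var ran = new List<string>();

// Plain delegates standing in for notification handlers; the second one fails.
var handlers = new List<Func<CancellationToken, Task>>
{
    ct => { ran.Add("first"); return Task.CompletedTask; },
    ct => { ran.Add("second"); return Task.FromException(new InvalidOperationException("second failed")); },
    ct => { ran.Add("third"); return Task.CompletedTask; },
};

Exception? caught = null;
try
{
    await PublishSequentially(handlers, CancellationToken.None);
}
catch (InvalidOperationException ex)
{
    caught = ex;
}

// Only "first" and "second" ran; the failure prevented "third" from running.
Console.WriteLine($"Ran: {string.Join(", ", ran)}; error: {caught?.Message}");

// Hypothetical helper mimicking the default sequential, stop-on-exception behavior.
static async Task PublishSequentially(IEnumerable<Func<CancellationToken, Task>> handlerList, CancellationToken cancellationToken)
{
    foreach (var handler in handlerList)
    {
        // A faulted task rethrows here, which ends the loop.
        await handler(cancellationToken);
    }
}
```

This is the behavior the strategies below extend: some of them keep going after a failure, and some run the handlers concurrently.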

Publishing Strategies

Let’s first create an enum with the strategies we plan to support.

public enum PublishStrategy
{
    /// <summary>
    /// Run the notification handlers one after another. Returns when all handlers have finished or an exception has been thrown. If a handler throws, the handlers after it are not run. This is MediatR's default behavior.
    /// </summary>
    AsyncSequentialStopOnException = 0,

    /// <summary>
    /// Run the notification handlers one after another. Returns when all handlers have finished. Any exceptions are collected and rethrown in an AggregateException.
    /// </summary>
    AsyncSequentialContinueOnException = 1,

    /// <summary>
    /// Invoke all notification handlers without awaiting each one, then await them together with Task.WhenAll. Returns when all handlers have finished. Any exceptions are collected and rethrown in an AggregateException.
    /// </summary>
    AsyncWhenAll = 2,

    /// <summary>
    /// Run each notification handler on a thread-pool thread using Task.Run(). Returns when all handlers have finished. If the call to Publish is awaited, any exceptions are captured in an AggregateException. Do not use this strategy if your handlers access the database: handlers resolved from the same scope share the same DbContext instance, and DbContext is not thread-safe.
    /// </summary>
    ParallelWhenAll = 3,

    /// <summary>
    /// Queue a single work item using Task.Run() and run all handlers sequentially inside it, continuing on exception. Returns immediately without waiting for any handler to finish. You cannot catch handler exceptions, even if you await the call to Publish; to improve traceability, they are caught internally and logged with ILogger, if one is available.
    /// </summary>
    AsyncNoWait = 4,

    /// <summary>
    /// Run each notification handler on a thread-pool thread using Task.Run(). Returns immediately without waiting for any handler to finish. You cannot catch handler exceptions, even if you await the call to Publish; to improve traceability, they are caught internally and logged with ILogger, if one is available. Do not use this strategy if your handlers access the database: handlers resolved from the same scope share the same DbContext instance, and DbContext is not thread-safe.
    /// </summary>
    ParallelNoWait = 5,
}
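Several of these strategies rely on Task.WhenAll, and the AggregateException wording deserves a closer look: awaiting the task returned by Task.WhenAll rethrows only the first exception, while the task's Exception property still holds every failure. A minimal standalone sketch (no MediatR involved; Fail is a hypothetical helper):

```csharp
using System;
using System.Threading.Tasks;

// Two already-faulted tasks standing in for failing handlers.
Task whenAll = Task.WhenAll(Fail("first"), Fail("second"));

Exception? awaited = null;
try
{
    await whenAll;
}
catch (Exception ex)
{
    // await unwraps the AggregateException and rethrows a single inner exception.
    awaited = ex;
}

// The task itself still carries all failures.
int failureCount = whenAll.Exception?.InnerExceptions.Count ?? 0;

Console.WriteLine($"await threw {awaited?.GetType().Name}; the task holds {failureCount} exceptions");

// Hypothetical helper: returns a task faulted with the given message.
static Task Fail(string message) => Task.FromException(new InvalidOperationException(message));
```

This is why a strategy that promises an AggregateException has to build one explicitly, from the exceptions it collected or from the WhenAll task's Exception property, rather than relying on await alone.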

Extending Mediator Implementation

Next, let’s define a custom mediator implementation that derives from MediatR’s Mediator class and overrides PublishCore:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MediatR;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

// Targets MediatR versions before 12 (see the note at the top of this post).
public class ExtendedMediator : Mediator
{
    private readonly ServiceFactory _serviceFactory;
    private readonly IServiceScopeFactory _serviceScopeFactory;
    private readonly Func<IEnumerable<Func<INotification, CancellationToken, Task>>, INotification, CancellationToken, Task> _publish;

    // Creates a short-lived mediator bound to a specific publishing strategy.
    private ExtendedMediator(
        ServiceFactory serviceFactory,
        Func<IEnumerable<Func<INotification, CancellationToken, Task>>, INotification, CancellationToken, Task> publish)
        : base(serviceFactory)
    {
        _serviceFactory = serviceFactory;
        _serviceScopeFactory = default!;
        _publish = publish;
    }

    // Resolved by the DI container; plain Publish calls keep MediatR's default behavior.
    public ExtendedMediator(ServiceFactory serviceFactory, IServiceScopeFactory serviceScopeFactory)
        : base(serviceFactory)
    {
        _serviceFactory = serviceFactory;
        _serviceScopeFactory = serviceScopeFactory;
        _publish = base.PublishCore;
    }

    protected override Task PublishCore(
        IEnumerable<Func<INotification, CancellationToken, Task>> allHandlers,
        INotification notification,
        CancellationToken cancellationToken)
    {
        return _publish(allHandlers, notification, cancellationToken);
    }

    public Task Publish<TNotification>(
        TNotification notification,
        PublishStrategy strategy,
        CancellationToken cancellationToken) where TNotification : INotification
    {
        return strategy switch
        {
            PublishStrategy.AsyncNoWait => PublishNoWait(_serviceScopeFactory, notification, AsyncSequentialContinueOnException, cancellationToken),
            PublishStrategy.ParallelNoWait => PublishNoWait(_serviceScopeFactory, notification, ParallelWhenAll, cancellationToken),
            PublishStrategy.AsyncSequentialContinueOnException => new ExtendedMediator(_serviceFactory, AsyncSequentialContinueOnException).Publish(notification, cancellationToken),
            PublishStrategy.AsyncSequentialStopOnException => new ExtendedMediator(_serviceFactory, AsyncSequentialStopOnException).Publish(notification, cancellationToken),
            PublishStrategy.AsyncWhenAll => new ExtendedMediator(_serviceFactory, AsyncWhenAll).Publish(notification, cancellationToken),
            PublishStrategy.ParallelWhenAll => new ExtendedMediator(_serviceFactory, ParallelWhenAll).Publish(notification, cancellationToken),
            _ => throw new ArgumentException($"Unknown strategy: {strategy}")
        };
    }

    private static Task PublishNoWait(
        IServiceScopeFactory serviceScopeFactory,
        INotification notification,
        Func<IEnumerable<Func<INotification, CancellationToken, Task>>, INotification, CancellationToken, Task> publish,
        CancellationToken cancellationToken)
    {
        _ = Task.Run(async () =>
        {
            // A new scope owned by the background task, so the handlers do not depend on
            // the caller's scope (e.g. the HTTP request scope), which may be disposed first.
            using var scope = serviceScopeFactory.CreateScope();
            var logger = scope.ServiceProvider.GetService<ILogger<ExtendedMediator>>();
            try
            {
                var mediator = new ExtendedMediator(scope.ServiceProvider.GetRequiredService, publish);
                await mediator.Publish(notification, cancellationToken).ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                logger?.LogError(ex, "Error occurred while executing the handler in NoWait mode");
            }
        }, cancellationToken);

        // Return immediately; the caller never observes the handlers' outcome.
        return Task.CompletedTask;
    }

    private static async Task ParallelWhenAll(
        IEnumerable<Func<INotification, CancellationToken, Task>> handlers,
        INotification notification,
        CancellationToken cancellationToken)
    {
        var tasks = new List<Task>();

        foreach (var handler in handlers)
        {
            // Each handler runs on a thread-pool thread.
            tasks.Add(Task.Run(() => handler(notification, cancellationToken), cancellationToken));
        }

        var whenAll = Task.WhenAll(tasks);
        try
        {
            await whenAll.ConfigureAwait(false);
        }
        catch when (whenAll.Exception is not null)
        {
            // await rethrows only the first exception; surface all of them instead.
            throw whenAll.Exception!.Flatten();
        }
    }

    private static async Task AsyncWhenAll(
        IEnumerable<Func<INotification, CancellationToken, Task>> handlers,
        INotification notification,
        CancellationToken cancellationToken)
    {
        var tasks = new List<Task>();
        var exceptions = new List<Exception>();

        foreach (var handler in handlers)
        {
            try
            {
                // Start the handler without awaiting it, so all handlers run concurrently.
                tasks.Add(handler(notification, cancellationToken));
            }
            catch (Exception ex) when (!(ex is OutOfMemoryException || ex is StackOverflowException))
            {
                // A handler that throws synchronously, before returning a task.
                exceptions.Add(ex);
            }
        }

        var whenAll = Task.WhenAll(tasks);
        try
        {
            await whenAll.ConfigureAwait(false);
        }
        catch (Exception ex) when (!(ex is OutOfMemoryException || ex is StackOverflowException))
        {
            // await rethrows only the first exception; the task itself holds all of them.
            if (whenAll.Exception is not null)
            {
                exceptions.AddRange(whenAll.Exception.Flatten().InnerExceptions);
            }
            else
            {
                // For example, an OperationCanceledException when the handlers were canceled.
                exceptions.Add(ex);
            }
        }

        if (exceptions.Any())
        {
            throw new AggregateException(exceptions);
        }
    }

    private static async Task AsyncSequentialContinueOnException(
        IEnumerable<Func<INotification, CancellationToken, Task>> handlers,
        INotification notification,
        CancellationToken cancellationToken)
    {
        var exceptions = new List<Exception>();

        foreach (var handler in handlers)
        {
            try
            {
                await handler(notification, cancellationToken).ConfigureAwait(false);
            }
            catch (AggregateException ex)
            {
                exceptions.AddRange(ex.Flatten().InnerExceptions);
            }
            catch (Exception ex) when (!(ex is OutOfMemoryException || ex is StackOverflowException))
            {
                exceptions.Add(ex);
            }
        }

        if (exceptions.Any())
        {
            throw new AggregateException(exceptions);
        }
    }

    private static async Task AsyncSequentialStopOnException(
        IEnumerable<Func<INotification, CancellationToken, Task>> handlers,
        INotification notification,
        CancellationToken cancellationToken)
    {
        foreach (var handler in handlers)
        {
            // The first exception propagates and the remaining handlers are skipped.
            await handler(notification, cancellationToken).ConfigureAwait(false);
        }
    }
}
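The strategies are ordinary static methods over a sequence of handler delegates, so the idea can be exercised without MediatR or a DI container. Below is a minimal standalone sketch of the AsyncSequentialContinueOnException approach; object stands in for INotification, and the handler delegates are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

var ran = new List<string>();

// Hypothetical handlers with the same (notification, token) => Task shape MediatR uses.
var handlers = new List<Func<object, CancellationToken, Task>>
{
    (n, ct) => { ran.Add("first"); return Task.FromException(new InvalidOperationException("first failed")); },
    (n, ct) => { ran.Add("second"); return Task.CompletedTask; },
    (n, ct) => { ran.Add("third"); return Task.FromException(new ArgumentException("third failed")); },
};

AggregateException? aggregate = null;
try
{
    await SequentialContinueOnException(handlers, "ping", CancellationToken.None);
}
catch (AggregateException ex)
{
    aggregate = ex;
}

Console.WriteLine($"Ran: {string.Join(", ", ran)}; failures: {aggregate?.InnerExceptions.Count ?? 0}");

// Every handler runs; failures are collected and rethrown together at the end.
static async Task SequentialContinueOnException(
    IEnumerable<Func<object, CancellationToken, Task>> handlerList,
    object notification,
    CancellationToken cancellationToken)
{
    var exceptions = new List<Exception>();
    foreach (var handler in handlerList)
    {
        try
        {
            await handler(notification, cancellationToken);
        }
        catch (Exception ex)
        {
            exceptions.Add(ex);
        }
    }

    if (exceptions.Count > 0)
    {
        throw new AggregateException(exceptions);
    }
}
```

Compared with the stop-on-exception variant, the only difference is the try/catch around each await, which trades fail-fast behavior for running every handler.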

Registration and Extensions

We’ll define a few extension methods to simplify registration and to add a Publish overload that accepts a strategy.

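using System;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using MediatR;
using Microsoft.Extensions.DependencyInjection;

public static class MediatorExtensions
{
    // Publish overload that accepts a strategy; requires ExtendedMediator to be the registered IMediator.
    public static Task Publish<TNotification>(this IMediator mediator, TNotification notification, PublishStrategy strategy, CancellationToken cancellationToken)
        where TNotification : INotification
    {
        if (mediator is ExtendedMediator customMediator)
        {
            return customMediator.Publish(notification, strategy, cancellationToken);
        }

        throw new NotSupportedException("ExtendedMediator is not registered as the IMediator implementation. Register MediatR with AddExtendedMediatR.");
    }

    // Registers MediatR with ExtendedMediator as the scoped IMediator implementation, scanning the given assemblies.
    public static IServiceCollection AddExtendedMediatR(this IServiceCollection services, params Assembly[] assemblies)
    {
        services.AddMediatR(options => options.Using<ExtendedMediator>().AsScoped(), assemblies);
        return services;
    }

    // Same as above, but scans the assemblies that contain the given marker types.
    public static IServiceCollection AddExtendedMediatR(this IServiceCollection services, params Type[] handlerAssemblyMarkerTypes)
    {
        services.AddMediatR(options => options.Using<ExtendedMediator>().AsScoped(), handlerAssemblyMarkerTypes);
        return services;
    }
}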

Usage

Now that we’re all set, usage is straightforward. First, register the extended mediator:

builder.Services.AddExtendedMediatR(typeof(Program));

Then publish through the existing IMediator contract, passing the desired strategy:

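using MediatR;
using Microsoft.AspNetCore.Mvc;

// Sample notification used below; any INotification works the same way.
public record Ping : INotification;

[ApiController]
public class DummyController : ControllerBase
{
    private readonly IMediator _mediator;

    public DummyController(IMediator mediator)
    {
        _mediator = mediator;
    }

    [HttpGet("/")]
    public async Task<ActionResult> Get(CancellationToken cancellationToken)
    {
        // AsyncNoWait returns immediately; handler failures are logged, not thrown here.
        await _mediator.Publish(new Ping(), PublishStrategy.AsyncNoWait, cancellationToken);
        return Ok();
    }
}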
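Because AsyncNoWait and ParallelNoWait return before the handlers run, the controller above responds with Ok() even if a handler later fails. The following self-contained sketch shows that behavior without MediatR. PublishNoWait here is a hypothetical, simplified stand-in for the private method of the same name, a queue replaces ILogger, and the non-generic TaskCompletionSource assumes .NET 5 or later.

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;

// Stand-in for ILogger: errors from background handlers end up here.
var loggedErrors = new ConcurrentQueue<Exception>();

// Lets us control when the background "handler" finishes, and observe when it has.
var handlerGate = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var backgroundDone = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

Task returned = PublishNoWait(
    async () =>
    {
        await handlerGate.Task;
        throw new InvalidOperationException("handler failed");
    },
    loggedErrors.Enqueue,
    () => backgroundDone.TrySetResult());

// The caller's await completes immediately and never observes the handler's exception.
await returned;
bool returnedBeforeHandlerFinished = !backgroundDone.Task.IsCompleted;

// Release the handler and wait for the background work to finish.
handlerGate.SetResult();
await backgroundDone.Task;

Console.WriteLine($"Returned early: {returnedBeforeHandlerFinished}; logged errors: {loggedErrors.Count}");

// Hypothetical helper with the same shape as PublishNoWait: start the work on the
// thread pool, catch and log failures there, and hand back a completed task at once.
static Task PublishNoWait(Func<Task> publish, Action<Exception> logError, Action onFinished)
{
    _ = Task.Run(async () =>
    {
        try
        {
            await publish();
        }
        catch (Exception ex)
        {
            logError(ex);
        }
        finally
        {
            onFinished();
        }
    });

    return Task.CompletedTask;
}
```

The trade-off is visibility: the caller cannot react to failures, so logging inside the background task is the only place they surface.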

I hope you found this article useful. Happy coding!

This post is licensed under CC BY 4.0 by the author.
