Browse Source

Caching.StackExchangeRedis: add "force reconnect" pattern (#45261)

* 1. add "force reconnect" pattern via OnRedisError
2. avoid thread-race problems by having Connect[Async] *return* the cache, rather than relying on a field
3. avoid thread-race problems by only using a single ref field (cache), rather than cache+muxer; access indirectly
4. use more efficient key-prefix path via RedisKey.Append
5. use more efficient hash-fetch by pre-storing the hash-key arrays (removes need for RedisExtensions.cs)
6. general code tidy; null-handling, defaults, etc (no semantic change)

* copy the time-based force reconnect logic from azure redis cache best practices

* remove a couple of dead comments

* PR feedback

* clarify timing

* Update src/Caching/StackExchangeRedis/src/RedisCache.cs

Co-authored-by: Stephen Halter <[email protected]>

* Update src/Caching/StackExchangeRedis/src/RedisCache.cs

Co-authored-by: Stephen Halter <[email protected]>

* 1. nits
2. add new UseForceReconnect option

* rev SE.Redis to 2.6.90, which has improvements to connection stability

* remove UseForceReconnect, preferring an app-context switch

* PR feedback; move the AppContext check to RedisCacheOptions

* - fixup NRT warnings (also uses ROM-byte support by default now)
- fixup interface delta in test harness
- make an analyzer happy about a redundant lambda

* Update src/Caching/StackExchangeRedis/src/RedisCache.cs

Co-authored-by: David Fowler <[email protected]>

* fixup PR drift

* fix NRT

---------

Co-authored-by: Stephen Halter <[email protected]>
Co-authored-by: Sébastien Ros <[email protected]>
Co-authored-by: David Fowler <[email protected]>
Marc Gravell 3 years ago
parent
commit
9fbc7c6b46

+ 1 - 1
eng/Versions.props

@@ -288,7 +288,7 @@
     <SeleniumWebDriverVersion>4.7.0</SeleniumWebDriverVersion>
     <SerilogExtensionsLoggingVersion>1.4.0</SerilogExtensionsLoggingVersion>
     <SerilogSinksFileVersion>4.0.0</SerilogSinksFileVersion>
-    <StackExchangeRedisVersion>2.2.4</StackExchangeRedisVersion>
+    <StackExchangeRedisVersion>2.6.90</StackExchangeRedisVersion>
     <SystemReactiveLinqVersion>5.0.0</SystemReactiveLinqVersion>
     <SwashbuckleAspNetCoreVersion>6.4.0</SwashbuckleAspNetCoreVersion>
     <XunitAbstractionsVersion>2.0.3</XunitAbstractionsVersion>

+ 251 - 85
src/Caching/StackExchangeRedis/src/RedisCache.cs

@@ -4,6 +4,8 @@
 using System;
 using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
+using System.Net.Sockets;
+using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.AspNetCore.Shared;
@@ -20,6 +22,9 @@ namespace Microsoft.Extensions.Caching.StackExchangeRedis;
 /// </summary>
 public partial class RedisCache : IDistributedCache, IDisposable
 {
+    // Note that the "force reconnect" pattern as described at https://learn.microsoft.com/en-us/azure/azure-cache-for-redis/cache-best-practices-connection#using-forcereconnect-with-stackexchangeredis
+    // can be enabled via the "Microsoft.AspNetCore.Caching.StackExchangeRedis.UseForceReconnect" app-context switch
+    //
     // -- Explanation of why two kinds of SetScript are used --
     // * Redis 2.0 had HSET key field value for setting individual hash fields,
     // and HMSET key field value [field value ...] for setting multiple hash fields (against the same key).
@@ -50,20 +55,54 @@ public partial class RedisCache : IDistributedCache, IDisposable
     private const string AbsoluteExpirationKey = "absexp";
     private const string SlidingExpirationKey = "sldexp";
     private const string DataKey = "data";
+
+    // combined keys - same hash keys fetched constantly; avoid allocating an array each time
+    private static readonly RedisValue[] _hashMembersAbsoluteExpirationSlidingExpirationData = new RedisValue[] { AbsoluteExpirationKey, SlidingExpirationKey, DataKey };
+    private static readonly RedisValue[] _hashMembersAbsoluteExpirationSlidingExpiration = new RedisValue[] { AbsoluteExpirationKey, SlidingExpirationKey };
+
+    // Selects the pre-allocated hash-field array to fetch: the two expiration fields plus the
+    // data field when getData is true, otherwise only the two expiration fields.
+    private static RedisValue[] GetHashFields(bool getData) => getData
+        ? _hashMembersAbsoluteExpirationSlidingExpirationData
+        : _hashMembersAbsoluteExpirationSlidingExpiration;
+
     private const long NotPresent = -1;
     private static readonly Version ServerVersionWithExtendedSetCommand = new Version(4, 0, 0);
 
-    private volatile IConnectionMultiplexer? _connection;
-    private IDatabase? _cache;
+    private volatile IDatabase? _cache;
     private bool _disposed;
     private string _setScript = SetScript;
 
     private readonly RedisCacheOptions _options;
-    private readonly string _instance;
+    private readonly RedisKey _instancePrefix;
     private readonly ILogger _logger;
 
     private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(initialCount: 1, maxCount: 1);
 
+    private long _lastConnectTicks = DateTimeOffset.UtcNow.Ticks;
+    private long _firstErrorTimeTicks;
+    private long _previousErrorTimeTicks;
+
+    // StackExchange.Redis will also be trying to reconnect internally,
+    // so limit how often we recreate the ConnectionMultiplexer instance
+    // in an attempt to reconnect
+
+    // Never reconnect within 60 seconds of the last attempt to connect or reconnect.
+    private readonly TimeSpan ReconnectMinInterval = TimeSpan.FromSeconds(60);
+    // Only reconnect if errors have been occurring for at least 30 seconds,
+    // and the previous error was seen within the last 30 seconds (i.e. there was no recent gap in errors).
+    private readonly TimeSpan ReconnectErrorThreshold = TimeSpan.FromSeconds(30);
+
+    // Reads a tick count stored in a long field and returns it as a DateTimeOffset with zero offset.
+    // A stored value of 0 means "not set" and is returned as DateTimeOffset.MinValue.
+    private static DateTimeOffset ReadTimeTicks(ref long field)
+    {
+        var ticks = Volatile.Read(ref field); // avoid torn values
+        return ticks == 0 ? DateTimeOffset.MinValue : new DateTimeOffset(ticks, TimeSpan.Zero);
+    }
+
+    // Stores a DateTimeOffset into a long field as UTC ticks (the counterpart of ReadTimeTicks).
+    // DateTimeOffset.MinValue is stored as 0, which ReadTimeTicks treats as "not set".
+    private static void WriteTimeTicks(ref long field, DateTimeOffset value)
+    {
+        var ticks = value == DateTimeOffset.MinValue ? 0L : value.UtcTicks;
+        Volatile.Write(ref field, ticks); // avoid torn values
+    }
+
     /// <summary>
     /// Initializes a new instance of <see cref="RedisCache"/>.
     /// </summary>
@@ -87,7 +126,14 @@ public partial class RedisCache : IDistributedCache, IDisposable
         _logger = logger;
 
         // This allows partitioning a single backend cache for use with multiple apps/services.
-        _instance = _options.InstanceName ?? string.Empty;
+        var instanceName = _options.InstanceName;
+        if (!string.IsNullOrEmpty(instanceName))
+        {
+            // SE.Redis allows efficient append of key-prefix scenarios, but we can help it
+            // avoid some work/allocations by forcing the key-prefix to be a byte[]; SE.Redis
+            // would do this itself anyway, using UTF8
+            _instancePrefix = (RedisKey)Encoding.UTF8.GetBytes(instanceName);
+        }
     }
 
     /// <inheritdoc />
@@ -99,7 +145,7 @@ public partial class RedisCache : IDistributedCache, IDisposable
     }
 
     /// <inheritdoc />
-    public async Task<byte[]?> GetAsync(string key, CancellationToken token = default(CancellationToken))
+    public async Task<byte[]?> GetAsync(string key, CancellationToken token = default)
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
@@ -115,24 +161,32 @@ public partial class RedisCache : IDistributedCache, IDisposable
         ArgumentNullThrowHelper.ThrowIfNull(value);
         ArgumentNullThrowHelper.ThrowIfNull(options);
 
-        Connect();
+        var cache = Connect();
 
         var creationTime = DateTimeOffset.UtcNow;
 
         var absoluteExpiration = GetAbsoluteExpiration(creationTime, options);
 
-        _cache.ScriptEvaluate(_setScript, new RedisKey[] { _instance + key },
-            new RedisValue[]
-            {
+        try
+        {
+            cache.ScriptEvaluate(_setScript, new RedisKey[] { _instancePrefix.Append(key) },
+                new RedisValue[]
+                {
                         absoluteExpiration?.Ticks ?? NotPresent,
                         options.SlidingExpiration?.Ticks ?? NotPresent,
                         GetExpirationInSeconds(creationTime, absoluteExpiration, options) ?? NotPresent,
                         value
-            });
+                });
+        }
+        catch (Exception ex)
+        {
+            OnRedisError(ex, cache);
+            throw;
+        }
     }
 
     /// <inheritdoc />
-    public async Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default(CancellationToken))
+    public async Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default)
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
         ArgumentNullThrowHelper.ThrowIfNull(value);
@@ -140,21 +194,29 @@ public partial class RedisCache : IDistributedCache, IDisposable
 
         token.ThrowIfCancellationRequested();
 
-        await ConnectAsync(token).ConfigureAwait(false);
-        Debug.Assert(_cache is not null);
+        var cache = await ConnectAsync(token).ConfigureAwait(false);
+        Debug.Assert(cache is not null);
 
         var creationTime = DateTimeOffset.UtcNow;
 
         var absoluteExpiration = GetAbsoluteExpiration(creationTime, options);
 
-        await _cache.ScriptEvaluateAsync(_setScript, new RedisKey[] { _instance + key },
-            new RedisValue[]
-            {
+        try
+        {
+            await cache.ScriptEvaluateAsync(_setScript, new RedisKey[] { _instancePrefix.Append(key) },
+                new RedisValue[]
+                {
                 absoluteExpiration?.Ticks ?? NotPresent,
                 options.SlidingExpiration?.Ticks ?? NotPresent,
                 GetExpirationInSeconds(creationTime, absoluteExpiration, options) ?? NotPresent,
                 value
-            }).ConfigureAwait(false);
+                }).ConfigureAwait(false);
+        }
+        catch (Exception ex)
+        {
+            OnRedisError(ex, cache);
+            throw;
+        }
     }
 
     /// <inheritdoc />
@@ -166,7 +228,7 @@ public partial class RedisCache : IDistributedCache, IDisposable
     }
 
     /// <inheritdoc />
-    public async Task RefreshAsync(string key, CancellationToken token = default(CancellationToken))
+    public async Task RefreshAsync(string key, CancellationToken token = default)
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
@@ -175,84 +237,95 @@ public partial class RedisCache : IDistributedCache, IDisposable
         await GetAndRefreshAsync(key, getData: false, token: token).ConfigureAwait(false);
     }
 
-    [MemberNotNull(nameof(_cache), nameof(_connection))]
-    private void Connect()
+    [MemberNotNull(nameof(_cache))]
+    private IDatabase Connect()
     {
         CheckDisposed();
-        if (_cache != null)
+        var cache = _cache;
+        if (cache is not null)
         {
-            Debug.Assert(_connection != null);
-            return;
+            Debug.Assert(_cache is not null);
+            return cache;
         }
 
         _connectionLock.Wait();
         try
         {
-            if (_cache == null)
+            cache = _cache;
+            if (cache is null)
             {
-                if (_options.ConnectionMultiplexerFactory == null)
+                IConnectionMultiplexer connection;
+                if (_options.ConnectionMultiplexerFactory is null)
                 {
                     if (_options.ConfigurationOptions is not null)
                     {
-                        _connection = ConnectionMultiplexer.Connect(_options.ConfigurationOptions);
+                        connection = ConnectionMultiplexer.Connect(_options.ConfigurationOptions);
                     }
                     else
                     {
-                        _connection = ConnectionMultiplexer.Connect(_options.Configuration);
+                        connection = ConnectionMultiplexer.Connect(_options.Configuration!);
                     }
                 }
                 else
                 {
-                    _connection = _options.ConnectionMultiplexerFactory().GetAwaiter().GetResult();
+                    connection = _options.ConnectionMultiplexerFactory().GetAwaiter().GetResult();
                 }
 
-                PrepareConnection();
-                _cache = _connection.GetDatabase();
+                PrepareConnection(connection);
+                cache = _cache = connection.GetDatabase();
             }
+            Debug.Assert(_cache is not null);
+            return cache;
         }
         finally
         {
             _connectionLock.Release();
         }
-
-        Debug.Assert(_connection != null);
     }
 
-    private async Task ConnectAsync(CancellationToken token = default(CancellationToken))
+    private ValueTask<IDatabase> ConnectAsync(CancellationToken token = default)
     {
         CheckDisposed();
         token.ThrowIfCancellationRequested();
 
-        if (_cache != null)
+        var cache = _cache;
+        if (cache is not null)
         {
-            Debug.Assert(_connection != null);
-            return;
+            Debug.Assert(_cache is not null);
+            return new(cache);
         }
-
+        return ConnectSlowAsync(token);
+    }
+    private async ValueTask<IDatabase> ConnectSlowAsync(CancellationToken token)
+    {
         await _connectionLock.WaitAsync(token).ConfigureAwait(false);
         try
         {
-            if (_cache == null)
+            var cache = _cache;
+            if (cache is null)
             {
+                IConnectionMultiplexer connection;
                 if (_options.ConnectionMultiplexerFactory is null)
                 {
                     if (_options.ConfigurationOptions is not null)
                     {
-                        _connection = await ConnectionMultiplexer.ConnectAsync(_options.ConfigurationOptions).ConfigureAwait(false);
+                        connection = await ConnectionMultiplexer.ConnectAsync(_options.ConfigurationOptions).ConfigureAwait(false);
                     }
                     else
                     {
-                        _connection = await ConnectionMultiplexer.ConnectAsync(_options.Configuration).ConfigureAwait(false);
+                        connection = await ConnectionMultiplexer.ConnectAsync(_options.Configuration!).ConfigureAwait(false);
                     }
                 }
                 else
                 {
-                    _connection = await _options.ConnectionMultiplexerFactory().ConfigureAwait(false);
+                    connection = await _options.ConnectionMultiplexerFactory().ConfigureAwait(false);
                 }
 
-                PrepareConnection();
-                _cache = _connection.GetDatabase();
+                PrepareConnection(connection);
+                cache = _cache = connection.GetDatabase();
             }
+            Debug.Assert(_cache is not null);
+            return cache;
         }
         finally
         {
@@ -260,21 +333,22 @@ public partial class RedisCache : IDistributedCache, IDisposable
         }
     }
 
-    private void PrepareConnection()
+    private void PrepareConnection(IConnectionMultiplexer connection)
     {
-        ValidateServerFeatures();
-        TryRegisterProfiler();
+        WriteTimeTicks(ref _lastConnectTicks, DateTimeOffset.UtcNow);
+        ValidateServerFeatures(connection);
+        TryRegisterProfiler(connection);
     }
 
-    private void ValidateServerFeatures()
+    private void ValidateServerFeatures(IConnectionMultiplexer connection)
     {
-        _ = _connection ?? throw new InvalidOperationException($"{nameof(_connection)} cannot be null.");
+        _ = connection ?? throw new InvalidOperationException($"{nameof(connection)} cannot be null.");
 
         try
         {
-            foreach (var endPoint in _connection.GetEndPoints())
+            foreach (var endPoint in connection.GetEndPoints())
             {
-                if (_connection.GetServer(endPoint).Version < ServerVersionWithExtendedSetCommand)
+                if (connection.GetServer(endPoint).Version < ServerVersionWithExtendedSetCommand)
                 {
                     _setScript = SetScriptPreExtendedSetCommand;
                     return;
@@ -291,13 +365,13 @@ public partial class RedisCache : IDistributedCache, IDisposable
         }
     }
 
-    private void TryRegisterProfiler()
+    private void TryRegisterProfiler(IConnectionMultiplexer connection)
     {
-        _ = _connection ?? throw new InvalidOperationException($"{nameof(_connection)} cannot be null.");
+        _ = connection ?? throw new InvalidOperationException($"{nameof(connection)} cannot be null.");
 
-        if (_options.ProfilingSession != null)
+        if (_options.ProfilingSession is not null)
         {
-            _connection.RegisterProfiler(_options.ProfilingSession);
+            connection.RegisterProfiler(_options.ProfilingSession);
         }
     }
 
@@ -305,25 +379,25 @@ public partial class RedisCache : IDistributedCache, IDisposable
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
-        Connect();
+        var cache = Connect();
 
         // This also resets the LRU status as desired.
         // TODO: Can this be done in one operation on the server side? Probably, the trick would just be the DateTimeOffset math.
         RedisValue[] results;
-        if (getData)
+        try
         {
-            results = _cache.HashMemberGet(_instance + key, AbsoluteExpirationKey, SlidingExpirationKey, DataKey);
+            results = cache.HashGet(_instancePrefix.Append(key), GetHashFields(getData));
         }
-        else
+        catch (Exception ex)
         {
-            results = _cache.HashMemberGet(_instance + key, AbsoluteExpirationKey, SlidingExpirationKey);
+            OnRedisError(ex, cache);
+            throw;
         }
 
-        // TODO: Error handling
         if (results.Length >= 2)
         {
             MapMetadata(results, out DateTimeOffset? absExpr, out TimeSpan? sldExpr);
-            Refresh(_cache, key, absExpr, sldExpr);
+            Refresh(cache, key, absExpr, sldExpr);
         }
 
         if (results.Length >= 3 && results[2].HasValue)
@@ -334,32 +408,32 @@ public partial class RedisCache : IDistributedCache, IDisposable
         return null;
     }
 
-    private async Task<byte[]?> GetAndRefreshAsync(string key, bool getData, CancellationToken token = default(CancellationToken))
+    private async Task<byte[]?> GetAndRefreshAsync(string key, bool getData, CancellationToken token = default)
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
         token.ThrowIfCancellationRequested();
 
-        await ConnectAsync(token).ConfigureAwait(false);
-        Debug.Assert(_cache is not null);
+        var cache = await ConnectAsync(token).ConfigureAwait(false);
+        Debug.Assert(cache is not null);
 
         // This also resets the LRU status as desired.
         // TODO: Can this be done in one operation on the server side? Probably, the trick would just be the DateTimeOffset math.
         RedisValue[] results;
-        if (getData)
+        try
         {
-            results = await _cache.HashMemberGetAsync(_instance + key, AbsoluteExpirationKey, SlidingExpirationKey, DataKey).ConfigureAwait(false);
+            results = await cache.HashGetAsync(_instancePrefix.Append(key), GetHashFields(getData)).ConfigureAwait(false);
         }
-        else
+        catch (Exception ex)
         {
-            results = await _cache.HashMemberGetAsync(_instance + key, AbsoluteExpirationKey, SlidingExpirationKey).ConfigureAwait(false);
+            OnRedisError(ex, cache);
+            throw;
         }
 
-        // TODO: Error handling
         if (results.Length >= 2)
         {
             MapMetadata(results, out DateTimeOffset? absExpr, out TimeSpan? sldExpr);
-            await RefreshAsync(_cache, key, absExpr, sldExpr, token).ConfigureAwait(false);
+            await RefreshAsync(cache, key, absExpr, sldExpr, token).ConfigureAwait(false);
         }
 
         if (results.Length >= 3 && results[2].HasValue)
@@ -375,22 +449,35 @@ public partial class RedisCache : IDistributedCache, IDisposable
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
-        Connect();
-
-        _cache.KeyDelete(_instance + key);
-        // TODO: Error handling
+        var cache = Connect();
+        try
+        {
+            cache.KeyDelete(_instancePrefix.Append(key));
+        }
+        catch (Exception ex)
+        {
+            OnRedisError(ex, cache);
+            throw;
+        }
     }
 
     /// <inheritdoc />
-    public async Task RemoveAsync(string key, CancellationToken token = default(CancellationToken))
+    public async Task RemoveAsync(string key, CancellationToken token = default)
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
-        await ConnectAsync(token).ConfigureAwait(false);
-        Debug.Assert(_cache is not null);
+        var cache = await ConnectAsync(token).ConfigureAwait(false);
+        Debug.Assert(cache is not null);
 
-        await _cache.KeyDeleteAsync(_instance + key).ConfigureAwait(false);
-        // TODO: Error handling
+        try
+        {
+            await cache.KeyDeleteAsync(_instancePrefix.Append(key)).ConfigureAwait(false);
+        }
+        catch (Exception ex)
+        {
+            OnRedisError(ex, cache);
+            throw;
+        }
     }
 
     private static void MapMetadata(RedisValue[] results, out DateTimeOffset? absoluteExpiration, out TimeSpan? slidingExpiration)
@@ -426,12 +513,19 @@ public partial class RedisCache : IDistributedCache, IDisposable
             {
                 expr = sldExpr;
             }
-            cache.KeyExpire(_instance + key, expr);
-            // TODO: Error handling
+            try
+            {
+                cache.KeyExpire(_instancePrefix.Append(key), expr);
+            }
+            catch (Exception ex)
+            {
+                OnRedisError(ex, cache);
+                throw;
+            }
         }
     }
 
-    private async Task RefreshAsync(IDatabase cache, string key, DateTimeOffset? absExpr, TimeSpan? sldExpr, CancellationToken token = default(CancellationToken))
+    private async Task RefreshAsync(IDatabase cache, string key, DateTimeOffset? absExpr, TimeSpan? sldExpr, CancellationToken token = default)
     {
         ArgumentNullThrowHelper.ThrowIfNull(key);
 
@@ -450,8 +544,15 @@ public partial class RedisCache : IDistributedCache, IDisposable
             {
                 expr = sldExpr;
             }
-            await cache.KeyExpireAsync(_instance + key, expr).ConfigureAwait(false);
-            // TODO: Error handling
+            try
+            {
+                await cache.KeyExpireAsync(_instancePrefix.Append(key), expr).ConfigureAwait(false);
+            }
+            catch (Exception ex)
+            {
+                OnRedisError(ex, cache);
+                throw;
+            }
         }
     }
 
@@ -501,11 +602,76 @@ public partial class RedisCache : IDistributedCache, IDisposable
         }
 
         _disposed = true;
-        _connection?.Close();
+        ReleaseConnection(Interlocked.Exchange(ref _cache, null));
+
     }
 
     private void CheckDisposed()
     {
         ObjectDisposedThrowHelper.ThrowIf(_disposed, this);
     }
+
+    // Invoked from the catch blocks around Redis operations, before the exception is rethrown.
+    // When the force-reconnect option is enabled and the failure is a RedisConnectionException or
+    // SocketException, decides from the recorded timestamps whether the shared connection should be dropped.
+    // This method does not reconnect by itself: it only clears _cache (if it still refers to the
+    // instance passed in) and releases that instance's multiplexer.
+    //   exception: the failure observed by the caller.
+    //   cache: the database instance the caller was using when the failure occurred.
+    private void OnRedisError(Exception exception, IDatabase cache)
+    {
+        if (_options.UseForceReconnect && (exception is RedisConnectionException or SocketException))
+        {
+            var utcNow = DateTimeOffset.UtcNow;
+            var previousConnectTime = ReadTimeTicks(ref _lastConnectTicks);
+            TimeSpan elapsedSinceLastReconnect = utcNow - previousConnectTime;
+
+            // We want to limit how often we perform this top-level reconnect, so we check how long it's been since our last attempt.
+            if (elapsedSinceLastReconnect < ReconnectMinInterval)
+            {
+                return;
+            }
+
+            var firstErrorTime = ReadTimeTicks(ref _firstErrorTimeTicks);
+            if (firstErrorTime == DateTimeOffset.MinValue)
+            {
+                // First error of a new error window: record it and wait for further errors before acting.
+                // note: order/timing here (between the two fields) is not critical
+                WriteTimeTicks(ref _firstErrorTimeTicks, utcNow);
+                WriteTimeTicks(ref _previousErrorTimeTicks, utcNow);
+                return;
+            }
+
+            TimeSpan elapsedSinceFirstError = utcNow - firstErrorTime;
+            TimeSpan elapsedSinceMostRecentError = utcNow - ReadTimeTicks(ref _previousErrorTimeTicks);
+
+            bool shouldReconnect =
+                    elapsedSinceFirstError >= ReconnectErrorThreshold // Make sure we gave the multiplexer enough time to reconnect on its own if it could.
+                    && elapsedSinceMostRecentError <= ReconnectErrorThreshold; // Make sure we aren't working on stale data (e.g. if there was a gap in errors, don't reconnect yet).
+
+            // Update the previousErrorTime timestamp to be now (e.g. this reconnect request).
+            WriteTimeTicks(ref _previousErrorTimeTicks, utcNow);
+
+            if (!shouldReconnect)
+            {
+                return;
+            }
+
+            // Reconnect decided: clear both error timestamps so the next error starts a fresh window.
+            WriteTimeTicks(ref _firstErrorTimeTicks, DateTimeOffset.MinValue);
+            WriteTimeTicks(ref _previousErrorTimeTicks, DateTimeOffset.MinValue);
+
+            // wipe the shared field, but *only* if it is still the cache we were
+            // thinking about (once it is null, the next caller will reconnect)
+            ReleaseConnection(Interlocked.CompareExchange(ref _cache, null, cache));
+        }
+    }
+
+    // Best-effort teardown of the connection multiplexer behind the given database instance.
+    // A null database (or one without a multiplexer) is ignored. Exceptions thrown by Close/Dispose
+    // are deliberately not propagated - they are only written to the debug output - because this is
+    // called from Dispose and from the error-handling path, where a secondary failure must not
+    // replace the original outcome.
+    private static void ReleaseConnection(IDatabase? cache)
+    {
+        var connection = cache?.Multiplexer;
+        if (connection is not null)
+        {
+            try
+            {
+                connection.Close();
+                connection.Dispose();
+            }
+            catch (Exception ex)
+            {
+                Debug.WriteLine(ex);
+            }
+        }
+    }
 }

+ 12 - 0
src/Caching/StackExchangeRedis/src/RedisCacheOptions.cs

@@ -45,4 +45,16 @@ public class RedisCacheOptions : IOptions<RedisCacheOptions>
     {
         get { return this; }
     }
+
+    // Backing store for UseForceReconnect; null until the value is first read or explicitly set.
+    private bool? _useForceReconnect;
+    // Whether the "force reconnect" handling is enabled.
+    // If no value has been set, the first read evaluates the
+    // "Microsoft.AspNetCore.Caching.StackExchangeRedis.UseForceReconnect" app-context switch
+    // (false when the switch is not defined) and caches the result for later reads.
+    internal bool UseForceReconnect
+    {
+        get
+        {
+            return _useForceReconnect ??= GetDefaultValue();
+            static bool GetDefaultValue() =>
+                AppContext.TryGetSwitch("Microsoft.AspNetCore.Caching.StackExchangeRedis.UseForceReconnect", out var value) && value;
+        }
+        set => _useForceReconnect = value;
+    }
 }

+ 0 - 36
src/Caching/StackExchangeRedis/src/RedisExtensions.cs

@@ -1,36 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System.Threading.Tasks;
-using StackExchange.Redis;
-
-namespace Microsoft.Extensions.Caching.StackExchangeRedis;
-
-internal static class RedisExtensions
-{
-    internal static RedisValue[] HashMemberGet(this IDatabase cache, string key, params string[] members)
-    {
-        // TODO: Error checking?
-        return cache.HashGet(key, GetRedisMembers(members));
-    }
-
-    internal static async Task<RedisValue[]> HashMemberGetAsync(
-        this IDatabase cache,
-        string key,
-        params string[] members)
-    {
-        // TODO: Error checking?
-        return await cache.HashGetAsync(key, GetRedisMembers(members)).ConfigureAwait(false);
-    }
-
-    private static RedisValue[] GetRedisMembers(params string[] members)
-    {
-        var redisMembers = new RedisValue[members.Length];
-        for (int i = 0; i < members.Length; i++)
-        {
-            redisMembers[i] = (RedisValue)members[i];
-        }
-
-        return redisMembers;
-    }
-}

+ 2 - 2
src/DataProtection/StackExchangeRedis/src/RedisXmlRepository.cs

@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
@@ -45,7 +45,7 @@ public class RedisXmlRepository : IXmlRepository
         var database = _databaseFactory();
         foreach (var value in database.ListRange(_key))
         {
-            yield return XElement.Parse(value);
+            yield return XElement.Parse((string)value!);
         }
     }
 

+ 1 - 1
src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs

@@ -15,7 +15,7 @@ internal static partial class RedisLog
     {
         if (logger.IsEnabled(LogLevel.Information) && endpoints.Count > 0)
         {
-            ConnectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), serverName);
+            ConnectingToEndpoints(logger, string.Join(", ", endpoints.Select(EndPointCollection.ToString)), serverName);
         }
     }
 

+ 10 - 7
src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs

@@ -455,7 +455,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
             {
                 RedisLog.ReceivedFromChannel(_logger, _channels.All);
 
-                var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message);
+                var invocation = RedisProtocol.ReadInvocation(channelMessage.Message);
 
                 var tasks = new List<Task>(_connections.Count);
 
@@ -483,7 +483,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
         {
             try
             {
-                var groupMessage = RedisProtocol.ReadGroupCommand((byte[])channelMessage.Message);
+                var groupMessage = RedisProtocol.ReadGroupCommand(channelMessage.Message);
 
                 var connection = _connections[groupMessage.ConnectionId];
                 if (connection == null)
@@ -518,7 +518,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
         var channel = await _bus!.SubscribeAsync(_channels.Ack(_serverName));
         channel.OnMessage(channelMessage =>
         {
-            var ackId = RedisProtocol.ReadAck((byte[])channelMessage.Message);
+            var ackId = RedisProtocol.ReadAck(channelMessage.Message);
 
             _ackHandler.TriggerAck(ackId);
         });
@@ -532,7 +532,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
         var channel = await _bus!.SubscribeAsync(connectionChannel);
         channel.OnMessage(channelMessage =>
         {
-            var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message);
+            var invocation = RedisProtocol.ReadInvocation(channelMessage.Message);
 
             // This is a Client result we need to setup state for the completion and forward the message to the client
             if (!string.IsNullOrEmpty(invocation.InvocationId))
@@ -591,7 +591,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
             {
                 try
                 {
-                    var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message);
+                    var invocation = RedisProtocol.ReadInvocation(channelMessage.Message);
 
                     var tasks = new List<Task>(subscriptions.Count);
                     foreach (var userConnection in subscriptions)
@@ -617,7 +617,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
         {
             try
             {
-                var invocation = RedisProtocol.ReadInvocation((byte[])channelMessage.Message);
+                var invocation = RedisProtocol.ReadInvocation(channelMessage.Message);
 
                 var tasks = new List<Task>(groupConnections.Count);
                 foreach (var groupConnection in groupConnections)
@@ -778,7 +778,10 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
                             return;
                         }
 
-                        RedisLog.ConnectionFailed(_logger, e.Exception);
+                        if (e.Exception is not null)
+                        {
+                            RedisLog.ConnectionFailed(_logger, e.Exception);
+                        }
                     };
 
                     if (_redisServerConnection.IsConnected)

+ 14 - 0
src/SignalR/server/StackExchangeRedis/test/TestConnectionMultiplexer.cs

@@ -10,6 +10,7 @@ using System.Reflection;
 using System.Threading;
 using System.Threading.Tasks;
 using StackExchange.Redis;
+using StackExchange.Redis.Maintenance;
 using StackExchange.Redis.Profiling;
 
 namespace Microsoft.AspNetCore.SignalR.Tests;
@@ -77,6 +78,12 @@ public class TestConnectionMultiplexer : IConnectionMultiplexer
 
     private readonly TestRedisServer _server;
 
+    // No-op event: handlers are neither stored nor invoked by this test double.
+    // Presumably present only to satisfy the IConnectionMultiplexer interface - confirm against the interface definition.
+    public event EventHandler<ServerMaintenanceEvent> ServerMaintenanceEvent
+    {
+        add { }
+        remove { }
+    }
+
     public TestConnectionMultiplexer(TestRedisServer server)
     {
         _server = server;
@@ -221,6 +228,13 @@ public class TestConnectionMultiplexer : IConnectionMultiplexer
     {
         throw new NotImplementedException();
     }
+
+    // Not supported by this test double; always throws NotImplementedException.
+    public IServer[] GetServers()
+    {
+        throw new NotImplementedException();
+    }
+
+    // Performs no work in this test double; returns a default (already completed) ValueTask.
+    public ValueTask DisposeAsync() => default;
 }
 
 public class TestRedisServer