Selaa lähdekoodia

Merge pull request #8492 from AvaloniaUI/fixes/8480-classes-reentrancy

Fix reentrancy problem in Classes listeners.
Max Katz 3 vuotta sitten
vanhempi
sitoutus
803e6dc9d0

+ 0 - 0
.ncrunch/ControlCatalog.v3.ncrunchproject → .ncrunch/ControlCatalog.net6.0.v3.ncrunchproject


+ 5 - 0
.ncrunch/ControlCatalog.netstandard2.0.v3.ncrunchproject

@@ -0,0 +1,5 @@
+<ProjectConfiguration>
+  <Settings>
+    <IgnoreThisComponentCompletely>True</IgnoreThisComponentCompletely>
+  </Settings>
+</ProjectConfiguration>

+ 2 - 3
src/Avalonia.Base/Controls/Classes.cs

@@ -1,8 +1,7 @@
 using System;
 using System.Collections.Generic;
 using Avalonia.Collections;
-
-#nullable enable
+using Avalonia.Utilities;
 
 namespace Avalonia.Controls
 {
@@ -14,7 +13,7 @@ namespace Avalonia.Controls
     /// </remarks>
     public class Classes : AvaloniaList<string>, IPseudoClasses
     {
-        private List<IClassesChangedListener>? _listeners;
+        private SafeEnumerableList<IClassesChangedListener>? _listeners;
 
         /// <summary>
         /// Initializes a new instance of the <see cref="Classes"/> class.

+ 89 - 0
src/Avalonia.Base/Utilities/SafeEnumerableList.cs

@@ -0,0 +1,89 @@
+using System.Collections;
+using System.Collections.Generic;
+
+namespace Avalonia.Utilities
+{
+    /// <summary>
+    /// Implements a simple list which is safe to modify during enumeration.
+    /// </summary>
+    /// <typeparam name="T">The item type.</typeparam>
+    /// <remarks>
+    /// Implements a list which, when written to while enumerating, performs a copy of the list
+    /// items. Note this this class doesn't actually implement <see cref="IList{T}"/> as it's not
+    /// currently needed - feel free to add missing methods etc.
+    /// </remarks>
+    internal class SafeEnumerableList<T> : IEnumerable<T>
+    {
+        private List<T> _list = new();
+        private int _generation;
+        private int _enumCount = 0;
+
+        public int Count => _list.Count;
+        internal List<T> Inner => _list;
+
+        public void Add(T item) => GetList().Add(item);
+        public bool Remove(T item) => GetList().Remove(item);
+
+        public Enumerator GetEnumerator() => new(this, _list);
+        IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();
+        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+        
+        private List<T> GetList()
+        {
+            if (_enumCount > 0)
+            {
+                _list = new(_list);
+                ++_generation;
+                _enumCount = 0;
+            }
+
+            return _list;
+        }
+
+        public struct Enumerator : IEnumerator<T>, IEnumerator
+        {
+            private readonly SafeEnumerableList<T> _owner;
+            private readonly List<T> _list;
+            private readonly int _generation;
+            private int _index;
+            private T? _current;
+
+            internal Enumerator(SafeEnumerableList<T> owner, List<T> list)
+            {
+                _owner = owner;
+                _list = list;
+                _generation = owner._generation;
+                _index = 0;
+                _current = default;
+                ++_owner._enumCount;
+            }
+
+            public void Dispose()
+            {
+                if (_owner._generation == _generation)
+                    --_owner._enumCount;
+            }
+
+            public bool MoveNext()
+            {
+                if (_index < _list.Count)
+                {
+                    _current = _list[_index++];
+                    return true;
+                }
+
+                _current = default;
+                return false;
+            }
+
+            public T Current => _current!;
+            object? IEnumerator.Current => _current;
+
+            void IEnumerator.Reset()
+            {
+                _index = 0;
+                _current = default;
+            }
+        }
+    }
+}

+ 31 - 0
tests/Avalonia.Controls.UnitTests/ClassesTests.cs

@@ -168,5 +168,36 @@ namespace Avalonia.Controls.UnitTests
 
             Assert.Equal(new[] { "foo" }, target);
         }
+
+        [Fact]
+        public void Listeners_Can_Be_Added_By_Listener()
+        {
+            var classes = new Classes();
+            var listener1 = new ClassesChangedListener(() => { });
+            var listener2 = new ClassesChangedListener(() => classes.AddListener(listener1));
+
+            classes.AddListener(listener2);
+            classes.Add("bar");
+        }
+
+        [Fact]
+        public void Listeners_Can_Be_Removed_By_Listener()
+        {
+            var classes = new Classes();
+            var listener1 = new ClassesChangedListener(() => { });
+            var listener2 = new ClassesChangedListener(() => classes.RemoveListener(listener1));
+
+            classes.AddListener(listener1);
+            classes.AddListener(listener2);
+            classes.Add("bar");
+        }
+
+        private class ClassesChangedListener : IClassesChangedListener
+        {
+            private Action _action;
+
+            public ClassesChangedListener(Action action) => _action = action;
+            public void Changed() => _action();
+        }
     }
 }

+ 130 - 0
tests/Avalonia.Controls.UnitTests/Utils/SafeEnumerableListTests.cs

@@ -0,0 +1,130 @@
+using System.Collections.Generic;
+using Avalonia.Utilities;
+using Xunit;
+
+namespace Avalonia.Controls.UnitTests.Utils
+{
+    public class SafeEnumerableListTests
+    {
+        [Fact]
+        public void List_Is_Not_Copied_Outside_Enumeration()
+        {
+            var target = new SafeEnumerableList<string>();
+            var inner = target.Inner;
+
+            target.Add("foo");
+            target.Add("bar");
+            target.Remove("foo");
+
+            Assert.Same(inner, target.Inner);
+        }
+
+        [Fact]
+        public void List_Is_Copied_Outside_Enumeration()
+        {
+            var target = new SafeEnumerableList<string>();
+            var inner = target.Inner;
+
+            target.Add("foo");
+
+            foreach (var i in target)
+            {
+                Assert.Same(inner, target.Inner);
+                target.Add("bar");
+                Assert.NotSame(inner, target.Inner);
+                Assert.Equal("foo", i);
+            }
+
+            inner = target.Inner;
+
+            foreach (var i in target)
+            {
+                target.Add("baz");
+                Assert.NotSame(inner, target.Inner);
+            }
+
+            Assert.Equal(new[] { "foo", "bar", "baz", "baz" }, target);
+        }
+
+        [Fact]
+        public void List_Is_Not_Copied_After_Enumeration()
+        {
+            var target = new SafeEnumerableList<string>();
+            var inner = target.Inner;
+
+            target.Add("foo");
+
+            foreach (var i in target)
+            {
+                target.Add("bar");
+                Assert.NotSame(inner, target.Inner);
+                inner = target.Inner;
+                Assert.Equal("foo", i);
+            }
+
+            target.Add("baz");
+            Assert.Same(inner, target.Inner);
+        }
+
+        [Fact]
+        public void List_Is_Copied_Only_Once_During_Enumeration()
+        {
+            var target = new SafeEnumerableList<string>();
+            var inner = target.Inner;
+
+            target.Add("foo");
+
+            foreach (var i in target)
+            {
+                target.Add("bar");
+                Assert.NotSame(inner, target.Inner);
+                inner = target.Inner;
+                target.Add("baz");
+                Assert.Same(inner, target.Inner);
+            }
+
+            target.Add("baz");
+        }
+
+        [Fact]
+        public void List_Is_Copied_During_Nested_Enumerations()
+        {
+            var target = new SafeEnumerableList<string>();
+            var initialInner = target.Inner;
+            var firstItems = new List<string>();
+            var secondItems = new List<string>();
+            List<string> firstInner;
+            List<string> secondInner;
+
+            target.Add("foo");
+
+            foreach (var i in target)
+            {
+                target.Add("bar");
+
+                firstInner = target.Inner;
+                Assert.NotSame(initialInner, firstInner);
+
+                foreach (var j in target)
+                {
+                    target.Add("baz");
+
+                    secondInner = target.Inner;
+                    Assert.NotSame(firstInner, secondInner);
+
+                    secondItems.Add(j);
+                }
+
+                firstItems.Add(i);
+            }
+
+            Assert.Equal(new[] { "foo" }, firstItems);
+            Assert.Equal(new[] { "foo", "bar" }, secondItems);
+            Assert.Equal(new[] { "foo", "bar", "baz", "baz" }, target);
+
+            var finalInner = target.Inner;
+            target.Add("final");
+            Assert.Same(finalInner, target.Inner);
+        }
+    }
+}