diff --git a/src/core/Akka.Tests/Util/Internal/Collections/ImmutableAvlTreeTests.cs b/src/core/Akka.Tests/Util/Internal/Collections/ImmutableAvlTreeTests.cs index 44a9ff549a9..5341637c30b 100644 --- a/src/core/Akka.Tests/Util/Internal/Collections/ImmutableAvlTreeTests.cs +++ b/src/core/Akka.Tests/Util/Internal/Collections/ImmutableAvlTreeTests.cs @@ -5,6 +5,7 @@ // //----------------------------------------------------------------------- +using System; using System.Linq; using Akka.TestKit; using Akka.Util.Internal.Collections; @@ -169,6 +170,83 @@ public void WhenRemovingNodeThatCausesUnbalanceTheTreeIsRebalanced() } + private static int TreeImbalance(ImmutableAvlTree tree) + { + var lh = NodeHeight(tree.Root.Left); + var rh = NodeHeight(tree.Root.Right); + return lh - rh; + } + + private static int NodeHeight(IBinaryTreeNode tree) + { + if (tree == null) return 0; + return Math.Max(NodeHeight(tree.Left), NodeHeight(tree.Right)) + 1; + } + + [Fact] + public void WhenRemovingNodeThatCausesLeftThenRightHeavyUnbalanceTheTreeIsRebalanced() + { + var tree = ImmutableAvlTree.Empty + .Add(60, 6) + .Add(20, 2) + .Add(80, 8) + .Add(10, 1) + .Add(70, 7) + .Add(40, 4) + .Add(30, 3) + .Add(50, 5); + Assert.Equal(60, tree.Root.Key); + Assert.Equal(20, tree.Root.Left.Key); + Assert.Equal(10, tree.Root.Left.Left.Key); + Assert.Equal(40, tree.Root.Left.Right.Key); + Assert.Equal(80, tree.Root.Right.Key); + Assert.Equal(70, tree.Root.Right.Left.Key); + Assert.Null(tree.Root.Right.Right); + tree = tree.Remove(70); + + Assert.InRange(TreeImbalance(tree), -1, 1); + + Assert.Equal(40, tree.Root.Key); + Assert.Equal(20, tree.Root.Left.Key); + Assert.Equal(60, tree.Root.Right.Key); + Assert.Equal(10, tree.Root.Left.Left.Key); + Assert.Equal(30, tree.Root.Left.Right.Key); + Assert.Equal(50, tree.Root.Right.Left.Key); + Assert.Equal(80, tree.Root.Right.Right.Key); + } + + [Fact] + public void WhenRemovingNodeThatCausesRightThenLeftHeavyUnbalanceTheTreeIsRebalanced() + { + var tree = ImmutableAvlTree.Empty + .Add(-60, 6) + .Add(-20, 2) + .Add(-80, 8) + .Add(-10, 1) + .Add(-70, 7) + .Add(-40, 4) + .Add(-30, 3) + .Add(-50, 5); + Assert.Equal(-60, tree.Root.Key); + Assert.Equal(-20, tree.Root.Right.Key); + Assert.Equal(-10, tree.Root.Right.Right.Key); + Assert.Equal(-40, tree.Root.Right.Left.Key); + Assert.Equal(-80, tree.Root.Left.Key); + Assert.Equal(-70, tree.Root.Left.Right.Key); + Assert.Null(tree.Root.Left.Left); + tree = tree.Remove(-70); + + Assert.InRange(TreeImbalance(tree), -1, 1); + + Assert.Equal(-40, tree.Root.Key); + Assert.Equal(-20, tree.Root.Right.Key); + Assert.Equal(-60, tree.Root.Left.Key); + Assert.Equal(-10, tree.Root.Right.Right.Key); + Assert.Equal(-30, tree.Root.Right.Left.Key); + Assert.Equal(-50, tree.Root.Left.Right.Key); + Assert.Equal(-80, tree.Root.Left.Left.Key); + } + [Fact] public void WhenRemovingNodeThatDoNotExistSameTreeIsReturned() { diff --git a/src/core/Akka/Util/Internal/Collections/ImmutableAvlTreeBase.cs b/src/core/Akka/Util/Internal/Collections/ImmutableAvlTreeBase.cs index 244c395d204..99f80200b0b 100644 --- a/src/core/Akka/Util/Internal/Collections/ImmutableAvlTreeBase.cs +++ b/src/core/Akka/Util/Internal/Collections/ImmutableAvlTreeBase.cs @@ -439,11 +439,11 @@ private bool TryRemove(TKey key, Node node, out Node newNode) private Node MakeBalanced(Node tree) { Node result; - if (IsRightHeavy(tree)) + if (IsRightHeavyAndUnbalanced(tree)) { result = IsLeftHeavy(tree.Right) ? DoubleLeftRotation(tree) : RotateLeft(tree); } - else if (IsLeftHeavy(tree)) + else if (IsLeftHeavyAndUnbalanced(tree)) { result = IsRightHeavy(tree.Left) ? DoubleRightRotation(tree) : RotateRight(tree); } @@ -452,8 +452,10 @@ private Node MakeBalanced(Node tree) return result; } - private bool IsRightHeavy(Node tree) { return Balance(tree) >= 2; } - private bool IsLeftHeavy(Node tree) { return Balance(tree) <= -2; } + private bool IsRightHeavyAndUnbalanced(Node tree) { return Balance(tree) >= 2; } + private bool IsLeftHeavyAndUnbalanced(Node tree) { return Balance(tree) <= -2; } + private bool IsRightHeavy(Node tree) { return Balance(tree) >= 1; } + private bool IsLeftHeavy(Node tree) { return Balance(tree) <= -1; } private int Balance(Node tree) {