diff --git a/CSharpMath.Tests/Generic/BiDictionaryTests.cs b/CSharpMath.Tests/Generic/BiDictionaryTests.cs new file mode 100644 index 000000000..8ab71efb5 --- /dev/null +++ b/CSharpMath.Tests/Generic/BiDictionaryTests.cs @@ -0,0 +1,73 @@ +using Xunit; + +namespace CSharpMath.Tests.Generic { + public class BiDictionaryTests { + [Fact] + public void TestRemove() { + var testBiDictionary = new BiDictionary { + { 0, "0" }, + { 1, "1" }, + { 2, "8" }, + { 3, "10" } + }; + Assert.Equal(4, testBiDictionary.Firsts.Count); + Assert.Equal(4, testBiDictionary.Seconds.Count); + + Assert.True(testBiDictionary.Remove(2, "8")); + Assert.False(testBiDictionary.ContainsByFirst(2)); + Assert.False(testBiDictionary.ContainsBySecond("8")); + Assert.Equal(3, testBiDictionary.Firsts.Count); + Assert.Equal(3, testBiDictionary.Seconds.Count); + + // Remove with wrong first key + Assert.False(testBiDictionary.Remove(4, "10")); + Assert.False(testBiDictionary.ContainsByFirst(4)); + Assert.True(testBiDictionary.ContainsBySecond("10")); + Assert.Equal(3, testBiDictionary.Firsts.Count); + Assert.Equal(3, testBiDictionary.Seconds.Count); + + // Remove with wrong second key + Assert.False(testBiDictionary.Remove(3, "15")); + Assert.True(testBiDictionary.ContainsByFirst(3)); + Assert.False(testBiDictionary.ContainsBySecond("15")); + Assert.Equal(3, testBiDictionary.Firsts.Count); + Assert.Equal(3, testBiDictionary.Seconds.Count); + + // Remove when both exists but not corresponding to each other + Assert.True(testBiDictionary.Remove(0, "1")); + Assert.False(testBiDictionary.ContainsByFirst(0)); + Assert.False(testBiDictionary.ContainsBySecond("1")); + Assert.Single(testBiDictionary.Firsts); + Assert.Single(testBiDictionary.Seconds); + } + + [Fact] + public void TestAddOrReplace() { + var testBiDictionary = new BiDictionary(); + + testBiDictionary.AddOrReplace(0, "Value1"); + Assert.Equal("Value1", testBiDictionary[0]); + Assert.Equal(0, testBiDictionary["Value1"]); + + testBiDictionary.AddOrReplace(2, "Value10"); + Assert.Equal("Value10", testBiDictionary[2]); + Assert.Equal(2, testBiDictionary["Value10"]); + + testBiDictionary.AddOrReplace(2, "Value2"); + Assert.Equal("Value2", testBiDictionary[2]); + Assert.Equal(2, testBiDictionary["Value2"]); + Assert.Equal(2, testBiDictionary.Firsts.Count); + Assert.Equal(2, testBiDictionary.Seconds.Count); + + testBiDictionary.AddOrReplace(3, "Value3"); + Assert.Equal("Value3", testBiDictionary[3]); + Assert.Equal(3, testBiDictionary["Value3"]); + + testBiDictionary.AddOrReplace(10, "Value3"); + Assert.Equal("Value3", testBiDictionary[10]); + Assert.Equal(10, testBiDictionary["Value3"]); + Assert.Equal(3, testBiDictionary.Firsts.Count); + Assert.Equal(3, testBiDictionary.Seconds.Count); + } + } +} diff --git a/CSharpMath/Generic/BiDictionary.cs b/CSharpMath/Generic/BiDictionary.cs index 61c671a61..39f3fc05a 100644 --- a/CSharpMath/Generic/BiDictionary.cs +++ b/CSharpMath/Generic/BiDictionary.cs @@ -203,6 +203,15 @@ public void Add(TFirst first, TSecond second) { secondToFirst.Add(second, first); } + public void AddOrReplace(TFirst first, TSecond second) { + if (firstToSecond.ContainsKey(first)) + RemoveByFirst(first); + if (secondToFirst.ContainsKey(second)) + RemoveBySecond(second); + firstToSecond.Add(first, second); + secondToFirst.Add(second, first); + } + public Dictionary.KeyCollection Firsts => firstToSecond.Keys; public Dictionary.KeyCollection Seconds => secondToFirst.Keys; @@ -229,14 +238,14 @@ IEnumerator> IEnumerable item) => Add(item.Key, item.Value); - + public void AddOrReplace(KeyValuePair item) => AddOrReplace(item.Key, item.Value); public void Clear() { firstToSecond.Clear(); secondToFirst.Clear(); } - public bool Contains(TFirst first) => firstToSecond.ContainsKey(first); - public bool Contains(TSecond second) => secondToFirst.ContainsKey(second); + public bool ContainsByFirst(TFirst first) => firstToSecond.ContainsKey(first); + public bool ContainsBySecond(TSecond second) => secondToFirst.ContainsKey(second); public bool Contains(KeyValuePair pair) => firstToSecond.TryGetValue(pair.Key, out var second) && EqualityComparer.Default.Equals(second, pair.Value); @@ -245,12 +254,22 @@ public void CopyTo(KeyValuePair[] array, int arrayIndex) { array[arrayIndex++] = pair; } - public bool Remove(TFirst first, TSecond second) => - firstToSecond.Remove(first) && secondToFirst.Remove(second); - public bool Remove(KeyValuePair pair) => - firstToSecond.Remove(pair.Key) && secondToFirst.Remove(pair.Value); - public bool RemoveFirst(TFirst first) => Remove(first, firstToSecond[first]); - public bool RemoveSecond(TSecond second) => Remove(secondToFirst[second], second); + public bool Remove(TFirst first, TSecond second) { + if (TryGetByFirst(first, out var svalue) && TryGetBySecond(second, out var fvalue)) { + + firstToSecond.Remove(first); + firstToSecond.Remove(fvalue); + + secondToFirst.Remove(second); + secondToFirst.Remove(svalue); + return true; + } + return false; + } + + public bool Remove(KeyValuePair pair) => Remove(pair.Key, pair.Value); + public bool RemoveByFirst(TFirst first) => Remove(first, firstToSecond[first]); + public bool RemoveBySecond(TSecond second) => Remove(secondToFirst[second], second); } public class MultiDictionary : IEnumerable> {