Skip to content

Commit

Permalink
Try making {Try}GetAlternateLookup instance methods on Dictionary/Has…
Browse files Browse the repository at this point in the history
…hSet (#106107)

We made them extension methods instead of instance methods to avoid potential native code size bloat. But the ergonomics of using these without partial generic inference is a bit painful, and we've had reports that it makes them harder to understand. This moves them to be instance methods, and we'll measure the impact on code size to re-evaluate the decision.
  • Loading branch information
stephentoub committed Aug 8, 2024
1 parent ca8e63e commit 44b6b2a
Show file tree
Hide file tree
Showing 11 changed files with 359 additions and 389 deletions.
8 changes: 4 additions & 4 deletions src/libraries/System.Collections/ref/System.Collections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -618,16 +618,12 @@ namespace System.Collections.Generic
public static partial class CollectionExtensions
{
public static void AddRange<T>(this System.Collections.Generic.List<T> list, params System.ReadOnlySpan<T> source) { }
public static System.Collections.Generic.Dictionary<TKey, TValue>.AlternateLookup<TAlternateKey> GetAlternateLookup<TKey, TValue, TAlternateKey>(this System.Collections.Generic.Dictionary<TKey, TValue> dictionary) where TKey : notnull where TAlternateKey : notnull, allows ref struct { throw null; }
public static System.Collections.Generic.HashSet<T>.AlternateLookup<TAlternate> GetAlternateLookup<T, TAlternate>(this System.Collections.Generic.HashSet<T> set) where TAlternate : allows ref struct { throw null; }
public static void CopyTo<T>(this System.Collections.Generic.List<T> list, System.Span<T> destination) { }
public static TValue? GetValueOrDefault<TKey, TValue>(this System.Collections.Generic.IReadOnlyDictionary<TKey, TValue> dictionary, TKey key) { throw null; }
public static TValue GetValueOrDefault<TKey, TValue>(this System.Collections.Generic.IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue) { throw null; }
public static void InsertRange<T>(this System.Collections.Generic.List<T> list, int index, params System.ReadOnlySpan<T> source) { }
public static bool Remove<TKey, TValue>(this System.Collections.Generic.IDictionary<TKey, TValue> dictionary, TKey key, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TValue value) { throw null; }
public static bool TryAdd<TKey, TValue>(this System.Collections.Generic.IDictionary<TKey, TValue> dictionary, TKey key, TValue value) { throw null; }
public static bool TryGetAlternateLookup<TKey, TValue, TAlternateKey>(this System.Collections.Generic.Dictionary<TKey, TValue> dictionary, out System.Collections.Generic.Dictionary<TKey, TValue>.AlternateLookup<TAlternateKey> lookup) where TKey : notnull where TAlternateKey : notnull, allows ref struct { throw null; }
public static bool TryGetAlternateLookup<T, TAlternate>(this System.Collections.Generic.HashSet<T> set, out System.Collections.Generic.HashSet<T>.AlternateLookup<TAlternate> lookup) where TAlternate : allows ref struct { throw null; }
public static System.Collections.ObjectModel.ReadOnlyCollection<T> AsReadOnly<T>(this IList<T> list) { throw null; }
public static System.Collections.ObjectModel.ReadOnlyDictionary<TKey, TValue> AsReadOnly<TKey, TValue>(this IDictionary<TKey, TValue> dictionary) where TKey : notnull { throw null; }
}
Expand Down Expand Up @@ -675,6 +671,7 @@ public void Clear() { }
public bool ContainsKey(TKey key) { throw null; }
public bool ContainsValue(TValue value) { throw null; }
public int EnsureCapacity(int capacity) { throw null; }
public System.Collections.Generic.Dictionary<TKey, TValue>.AlternateLookup<TAlternateKey> GetAlternateLookup<TAlternateKey>() where TAlternateKey : notnull, allows ref struct { throw null; }
public System.Collections.Generic.Dictionary<TKey, TValue>.Enumerator GetEnumerator() { throw null; }
[System.ObsoleteAttribute("This API supports obsolete formatter-based serialization. It should not be called or extended by application code.", DiagnosticId = "SYSLIB0051", UrlFormat = "https://aka.ms/dotnet-warnings/{0}")]
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
Expand All @@ -696,6 +693,7 @@ void System.Collections.IDictionary.Remove(object key) { }
public void TrimExcess() { }
public void TrimExcess(int capacity) { }
public bool TryAdd(TKey key, TValue value) { throw null; }
public bool TryGetAlternateLookup<TAlternateKey>(out System.Collections.Generic.Dictionary<TKey, TValue>.AlternateLookup<TAlternateKey> lookup) where TAlternateKey : notnull, allows ref struct { throw null; }
public bool TryGetValue(TKey key, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TValue value) { throw null; }
public readonly partial struct AlternateLookup<TAlternateKey> where TAlternateKey : notnull, allows ref struct
{
Expand Down Expand Up @@ -812,6 +810,7 @@ public void CopyTo(T[] array, int arrayIndex, int count) { }
public static System.Collections.Generic.IEqualityComparer<System.Collections.Generic.HashSet<T>> CreateSetComparer() { throw null; }
public int EnsureCapacity(int capacity) { throw null; }
public void ExceptWith(System.Collections.Generic.IEnumerable<T> other) { }
public System.Collections.Generic.HashSet<T>.AlternateLookup<TAlternate> GetAlternateLookup<TAlternate>() where TAlternate : allows ref struct { throw null; }
public System.Collections.Generic.HashSet<T>.Enumerator GetEnumerator() { throw null; }
[System.ObsoleteAttribute("This API supports obsolete formatter-based serialization. It should not be called or extended by application code.", DiagnosticId = "SYSLIB0051", UrlFormat = "https://aka.ms/dotnet-warnings/{0}")]
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
Expand All @@ -832,6 +831,7 @@ void System.Collections.Generic.ICollection<T>.Add(T item) { }
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; }
public void TrimExcess() { }
public void TrimExcess(int capacity) { }
public bool TryGetAlternateLookup<TAlternate>(out System.Collections.Generic.HashSet<T>.AlternateLookup<TAlternate> lookup) where TAlternate : allows ref struct { throw null; }
public bool TryGetValue(T equalValue, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out T actualValue) { throw null; }
public void UnionWith(System.Collections.Generic.IEnumerable<T> other) { }
public readonly partial struct AlternateLookup<TAlternate> where TAlternate : allows ref struct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,260 +141,14 @@ public void AsReadOnly_NullIDictionary_ThrowsArgumentNullException()
Assert.Throws<ArgumentNullException>("dictionary", () => dictionary.AsReadOnly());
}

[Fact]
public void GetAlternateLookup_ThrowsWhenNull()
{
AssertExtensions.Throws<ArgumentNullException>("dictionary", () => CollectionExtensions.GetAlternateLookup<int, int, long>((Dictionary<int, int>)null));
AssertExtensions.Throws<ArgumentNullException>("dictionary", () => CollectionExtensions.TryGetAlternateLookup<int, int, long>((Dictionary<int, int>)null, out _));

AssertExtensions.Throws<ArgumentNullException>("set", () => CollectionExtensions.GetAlternateLookup<int, long>((HashSet<int>)null));
AssertExtensions.Throws<ArgumentNullException>("set", () => CollectionExtensions.TryGetAlternateLookup<int, long>((HashSet<int>)null, out _));
}

[Fact]
public void GetAlternateLookup_FailsWhenIncompatible()
{
var dictionary = new Dictionary<string, string>(StringComparer.Ordinal);
var hashSet = new HashSet<string>(StringComparer.Ordinal);

dictionary.GetAlternateLookup<string, string, ReadOnlySpan<char>>();
Assert.True(dictionary.TryGetAlternateLookup<string, string, ReadOnlySpan<char>>(out _));

hashSet.GetAlternateLookup<string, ReadOnlySpan<char>>();
Assert.True(hashSet.TryGetAlternateLookup<string, ReadOnlySpan<char>>(out _));

Assert.Throws<InvalidOperationException>(() => dictionary.GetAlternateLookup<string, string, ReadOnlySpan<byte>>());
Assert.Throws<InvalidOperationException>(() => dictionary.GetAlternateLookup<string, string, string>());
Assert.Throws<InvalidOperationException>(() => dictionary.GetAlternateLookup<string, string, int>());

Assert.False(dictionary.TryGetAlternateLookup<string, string, ReadOnlySpan<byte>>(out _));
Assert.False(dictionary.TryGetAlternateLookup<string, string, string>(out _));
Assert.False(dictionary.TryGetAlternateLookup<string, string, int>(out _));

Assert.Throws<InvalidOperationException>(() => hashSet.GetAlternateLookup<string, ReadOnlySpan<byte>>());
Assert.Throws<InvalidOperationException>(() => hashSet.GetAlternateLookup<string, string>());
Assert.Throws<InvalidOperationException>(() => hashSet.GetAlternateLookup<string, int>());

Assert.False(hashSet.TryGetAlternateLookup<string, ReadOnlySpan<byte>>(out _));
Assert.False(hashSet.TryGetAlternateLookup<string, string>(out _));
Assert.False(hashSet.TryGetAlternateLookup<string, int>(out _));
}

public static IEnumerable<object[]> Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary_MemberData()
{
yield return new object[] { EqualityComparer<string>.Default };
yield return new object[] { StringComparer.Ordinal };
yield return new object[] { StringComparer.OrdinalIgnoreCase };
yield return new object[] { StringComparer.InvariantCulture };
yield return new object[] { StringComparer.InvariantCultureIgnoreCase };
yield return new object[] { StringComparer.CurrentCulture };
yield return new object[] { StringComparer.CurrentCultureIgnoreCase };
}

[Theory]
[MemberData(nameof(Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary_MemberData))]
public void Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary(IEqualityComparer<string> comparer)
{
// Test with a variety of comparers to ensure that the alternate lookup is consistent with the underlying dictionary
Dictionary<string, int> dictionary = new(comparer);
Dictionary<string, int>.AlternateLookup<ReadOnlySpan<char>> lookup = dictionary.GetAlternateLookup<string, int, ReadOnlySpan<char>>();
Assert.Same(dictionary, lookup.Dictionary);
Assert.Same(lookup.Dictionary, lookup.Dictionary);

string actualKey;
int value;

// Add to the dictionary and validate that the lookup reflects the changes
dictionary["123"] = 123;
Assert.True(lookup.ContainsKey("123".AsSpan()));
Assert.True(lookup.TryGetValue("123".AsSpan(), out value));
Assert.Equal(123, value);
Assert.Equal(123, lookup["123".AsSpan()]);
Assert.False(lookup.TryAdd("123".AsSpan(), 321));
Assert.True(lookup.Remove("123".AsSpan()));
Assert.False(dictionary.ContainsKey("123"));
Assert.Throws<KeyNotFoundException>(() => lookup["123".AsSpan()]);

// Add via the lookup and validate that the dictionary reflects the changes
Assert.True(lookup.TryAdd("123".AsSpan(), 123));
Assert.True(dictionary.ContainsKey("123"));
lookup.TryGetValue("123".AsSpan(), out value);
Assert.Equal(123, value);
Assert.False(lookup.Remove("321".AsSpan(), out actualKey, out value));
Assert.Null(actualKey);
Assert.Equal(0, value);
Assert.True(lookup.Remove("123".AsSpan(), out actualKey, out value));
Assert.Equal("123", actualKey);
Assert.Equal(123, value);

// Ensure that case-sensitivity of the comparer is respected
lookup["a".AsSpan()] = 42;
if (dictionary.Comparer.Equals(EqualityComparer<string>.Default) ||
dictionary.Comparer.Equals(StringComparer.Ordinal) ||
dictionary.Comparer.Equals(StringComparer.InvariantCulture) ||
dictionary.Comparer.Equals(StringComparer.CurrentCulture))
{
Assert.True(lookup.TryGetValue("a".AsSpan(), out actualKey, out value));
Assert.Equal("a", actualKey);
Assert.Equal(42, value);
Assert.True(lookup.TryAdd("A".AsSpan(), 42));
Assert.True(lookup.Remove("a".AsSpan()));
Assert.False(lookup.Remove("a".AsSpan()));
Assert.True(lookup.Remove("A".AsSpan()));
}
else
{
Assert.True(lookup.TryGetValue("A".AsSpan(), out actualKey, out value));
Assert.Equal("a", actualKey);
Assert.Equal(42, value);
Assert.False(lookup.TryAdd("A".AsSpan(), 42));
Assert.True(lookup.Remove("A".AsSpan()));
Assert.False(lookup.Remove("a".AsSpan()));
Assert.False(lookup.Remove("A".AsSpan()));
}

// Validate overwrites
lookup["a".AsSpan()] = 42;
Assert.Equal(42, dictionary["a"]);
lookup["a".AsSpan()] = 43;
Assert.True(lookup.Remove("a".AsSpan(), out actualKey, out value));
Assert.Equal("a", actualKey);
Assert.Equal(43, value);

// Test adding multiple entries via the lookup
for (int i = 0; i < 10; i++)
{
Assert.Equal(i, dictionary.Count);
Assert.True(lookup.TryAdd(i.ToString().AsSpan(), i));
Assert.False(lookup.TryAdd(i.ToString().AsSpan(), i));
}

Assert.Equal(10, dictionary.Count);

// Test that the lookup and the dictionary agree on what's in and not in
for (int i = -1; i <= 10; i++)
{
Assert.Equal(dictionary.TryGetValue(i.ToString(), out int dv), lookup.TryGetValue(i.ToString().AsSpan(), out int lv));
Assert.Equal(dv, lv);
}

// Test removing multiple entries via the lookup
for (int i = 9; i >= 0; i--)
{
Assert.True(lookup.Remove(i.ToString().AsSpan(), out actualKey, out value));
Assert.Equal(i.ToString(), actualKey);
Assert.Equal(i, value);
Assert.False(lookup.Remove(i.ToString().AsSpan(), out actualKey, out value));
Assert.Null(actualKey);
Assert.Equal(0, value);
Assert.Equal(i, dictionary.Count);
}
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(2)]
[InlineData(3)]
[InlineData(4)]
[InlineData(5)]
public void HashSet_GetAlternateLookup_OperationsMatchUnderlyingSet(int mode)
{
// Test with a variety of comparers to ensure that the alternate lookup is consistent with the underlying set
HashSet<string> set = new(mode switch
{
0 => StringComparer.Ordinal,
1 => StringComparer.OrdinalIgnoreCase,
2 => StringComparer.InvariantCulture,
3 => StringComparer.InvariantCultureIgnoreCase,
4 => StringComparer.CurrentCulture,
5 => StringComparer.CurrentCultureIgnoreCase,
_ => throw new ArgumentOutOfRangeException(nameof(mode))
});
HashSet<string>.AlternateLookup<ReadOnlySpan<char>> lookup = set.GetAlternateLookup<string, ReadOnlySpan<char>>();
Assert.Same(set, lookup.Set);
Assert.Same(lookup.Set, lookup.Set);

// Add to the set and validate that the lookup reflects the changes
Assert.True(set.Add("123"));
Assert.True(lookup.Contains("123".AsSpan()));
Assert.False(lookup.Add("123".AsSpan()));
Assert.True(lookup.Remove("123".AsSpan()));
Assert.False(set.Contains("123"));

// Add via the lookup and validate that the set reflects the changes
Assert.True(lookup.Add("123".AsSpan()));
Assert.True(set.Contains("123"));
lookup.TryGetValue("123".AsSpan(), out string value);
Assert.Equal("123", value);
Assert.False(lookup.Remove("321".AsSpan()));
Assert.True(lookup.Remove("123".AsSpan()));

// Ensure that case-sensitivity of the comparer is respected
Assert.True(lookup.Add("a"));
if (set.Comparer.Equals(StringComparer.Ordinal) ||
set.Comparer.Equals(StringComparer.InvariantCulture) ||
set.Comparer.Equals(StringComparer.CurrentCulture))
{
Assert.True(lookup.Add("A".AsSpan()));
Assert.True(lookup.Remove("a".AsSpan()));
Assert.False(lookup.Remove("a".AsSpan()));
Assert.True(lookup.Remove("A".AsSpan()));
}
else
{
Assert.False(lookup.Add("A".AsSpan()));
Assert.True(lookup.Remove("A".AsSpan()));
Assert.False(lookup.Remove("a".AsSpan()));
Assert.False(lookup.Remove("A".AsSpan()));
}

// Test the behavior of null vs "" in the set and lookup
Assert.True(set.Add(null));
Assert.True(set.Add(string.Empty));
Assert.True(set.Contains(null));
Assert.True(set.Contains(""));
Assert.True(lookup.Contains("".AsSpan()));
Assert.True(lookup.Remove("".AsSpan()));
Assert.Equal(1, set.Count);
Assert.False(lookup.Remove("".AsSpan()));
Assert.True(set.Remove(null));
Assert.Equal(0, set.Count);

// Test adding multiple entries via the lookup
for (int i = 0; i < 10; i++)
{
Assert.Equal(i, set.Count);
Assert.True(lookup.Add(i.ToString().AsSpan()));
Assert.False(lookup.Add(i.ToString().AsSpan()));
}

Assert.Equal(10, set.Count);

// Test that the lookup and the set agree on what's in and not in
for (int i = -1; i <= 10; i++)
{
Assert.Equal(set.TryGetValue(i.ToString(), out string dv), lookup.TryGetValue(i.ToString().AsSpan(), out string lv));
Assert.Equal(dv, lv);
}

// Test removing multiple entries via the lookup
for (int i = 9; i >= 0; i--)
{
Assert.True(lookup.Remove(i.ToString().AsSpan()));
Assert.False(lookup.Remove(i.ToString().AsSpan()));
Assert.Equal(i, set.Count);
}
}

[Fact]
public void Dictionary_NotCorruptedByThrowingComparer()
{
Dictionary<string, string> dict = new(new CreateThrowsComparer());

Assert.Equal(0, dict.Count);

Assert.Throws<FormatException>(() => dict.GetAlternateLookup<string, string, ReadOnlySpan<char>>().TryAdd("123".AsSpan(), "123"));
Assert.Throws<FormatException>(() => dict.GetAlternateLookup<ReadOnlySpan<char>>().TryAdd("123".AsSpan(), "123"));
Assert.Equal(0, dict.Count);

dict.Add("123", "123");
Expand All @@ -408,7 +162,7 @@ public void Dictionary_NotCorruptedByNullReturningComparer()

Assert.Equal(0, dict.Count);

Assert.ThrowsAny<ArgumentException>(() => dict.GetAlternateLookup<string, string, ReadOnlySpan<char>>().TryAdd("123".AsSpan(), "123"));
Assert.ThrowsAny<ArgumentException>(() => dict.GetAlternateLookup<ReadOnlySpan<char>>().TryAdd("123".AsSpan(), "123"));
Assert.Equal(0, dict.Count);

dict.Add("123", "123");
Expand All @@ -422,7 +176,7 @@ public void HashSet_NotCorruptedByThrowingComparer()

Assert.Equal(0, set.Count);

Assert.Throws<FormatException>(() => set.GetAlternateLookup<string, ReadOnlySpan<char>>().Add("123".AsSpan()));
Assert.Throws<FormatException>(() => set.GetAlternateLookup<ReadOnlySpan<char>>().Add("123".AsSpan()));
Assert.Equal(0, set.Count);

set.Add("123");
Expand Down
Loading

0 comments on commit 44b6b2a

Please sign in to comment.