Skip to content

Commit

Permalink
Vectorize String.Equals for OrdinalIgnoreCase (#77947)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
EgorBo and stephentoub committed Nov 11, 2022
1 parent e319104 commit 1980c7b
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Text.Unicode;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;

namespace System.Globalization
{
Expand Down Expand Up @@ -75,7 +76,62 @@ internal static int CompareStringIgnoreCaseNonAscii(ref char strA, int lengthA,
return OrdinalCasing.CompareStringIgnoreCase(ref strA, lengthA, ref strB, lengthB);
}

private static bool EqualsIgnoreCase_Vector128(ref char charA, ref char charB, int length)
{
Debug.Assert(length >= Vector128<ushort>.Count);
Debug.Assert(Vector128.IsHardwareAccelerated);

nuint lengthU = (nuint)length;
nuint lengthToExamine = lengthU - (nuint)Vector128<ushort>.Count;
nuint i = 0;
Vector128<ushort> vec1;
Vector128<ushort> vec2;
do
{
vec1 = Vector128.LoadUnsafe(ref Unsafe.As<char, ushort>(ref charA), i);
vec2 = Vector128.LoadUnsafe(ref Unsafe.As<char, ushort>(ref charB), i);

if (!Utf16Utility.AllCharsInVector128AreAscii(vec1 | vec2))
{
goto NON_ASCII;
}

if (!Utf16Utility.Vector128OrdinalIgnoreCaseAscii(vec1, vec2))
{
return false;
}

i += (nuint)Vector128<ushort>.Count;
} while (i <= lengthToExamine);

// Use scalar path for trailing elements
return i == lengthU || EqualsIgnoreCase(ref Unsafe.Add(ref charA, i), ref Unsafe.Add(ref charB, i), (int)(lengthU - i));

NON_ASCII:
if (Utf16Utility.AllCharsInVector128AreAscii(vec1) || Utf16Utility.AllCharsInVector128AreAscii(vec2))
{
// No need to use the fallback if one of the inputs is full-ASCII
return false;
}

// Fallback for Non-ASCII inputs
return CompareStringIgnoreCase(
ref Unsafe.Add(ref charA, i), (int)(lengthU - i),
ref Unsafe.Add(ref charB, i), (int)(lengthU - i)) == 0;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool EqualsIgnoreCase(ref char charA, ref char charB, int length)
{
if (!Vector128.IsHardwareAccelerated || length < Vector128<ushort>.Count)
{
return EqualsIgnoreCase_Scalar(ref charA, ref charB, length);
}

return EqualsIgnoreCase_Vector128(ref charA, ref charB, length);
}

internal static bool EqualsIgnoreCase_Scalar(ref char charA, ref char charB, int length)
{
IntPtr byteOffset = IntPtr.Zero;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Runtime.CompilerServices;
using System.Diagnostics;
using System.Runtime.Intrinsics;

namespace System.Text.Unicode
{
Expand Down Expand Up @@ -217,5 +218,43 @@ internal static bool UInt64OrdinalIgnoreCaseAscii(ulong valueA, ulong valueB)
indicator |= 0xFF7F_FF7F_FF7F_FF7Ful;
return (differentBits & indicator) == 0;
}

/// <summary>
/// Returns true iff the Vector128 represents 8 ASCII UTF-16 characters in machine endianness.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool AllCharsInVector128AreAscii(Vector128<ushort> vec)
{
return (vec & Vector128.Create(unchecked((ushort)~0x007F))) == Vector128<ushort>.Zero;
}

/// <summary>
/// Given two Vector128 that represent 8 ASCII UTF-16 characters each, returns true iff
/// the two inputs are equal using an ordinal case-insensitive comparison.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool Vector128OrdinalIgnoreCaseAscii(Vector128<ushort> vec1, Vector128<ushort> vec2)
{
// ASSUMPTION: Caller has validated that input values are ASCII.

// the 0x80 bit of each word of 'lowerIndicator' will be set iff the word has value >= 'A'
Vector128<sbyte> lowIndicator1 = Vector128.Create((sbyte)(0x80 - 'A')) + vec1.AsSByte();
Vector128<sbyte> lowIndicator2 = Vector128.Create((sbyte)(0x80 - 'A')) + vec2.AsSByte();

// the 0x80 bit of each word of 'combinedIndicator' will be set iff the word has value >= 'A' and <= 'Z'
Vector128<sbyte> combIndicator1 =
Vector128.LessThan(Vector128.Create(unchecked((sbyte)(('Z' - 'A') - 0x80))), lowIndicator1);
Vector128<sbyte> combIndicator2 =
Vector128.LessThan(Vector128.Create(unchecked((sbyte)(('Z' - 'A') - 0x80))), lowIndicator2);

// Convert both vectors to lower case by adding 0x20 bit for all [A-Z][a-z] characters
Vector128<sbyte> lcVec1 =
Vector128.AndNot(Vector128.Create((sbyte)0x20), combIndicator1) + vec1.AsSByte();
Vector128<sbyte> lcVec2 =
Vector128.AndNot(Vector128.Create((sbyte)0x20), combIndicator2) + vec2.AsSByte();

// Compare two lowercased vectors
return (lcVec1 ^ lcVec2) == Vector128<sbyte>.Zero;
}
}
}

0 comments on commit 1980c7b

Please sign in to comment.