diff --git a/src/libraries/System.Memory/tests/Span/Contains.T.cs b/src/libraries/System.Memory/tests/Span/Contains.T.cs index a6846e5aa3e586..c1543c098e2547 100644 --- a/src/libraries/System.Memory/tests/Span/Contains.T.cs +++ b/src/libraries/System.Memory/tests/Span/Contains.T.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using System.Linq; using Xunit; namespace System.SpanTests @@ -193,5 +195,114 @@ public static void ContainsNull_String(string[] spanInput, bool expected) Span theStrings = spanInput; Assert.Equal(expected, theStrings.Contains(null)); } + + [Theory] + [InlineData(new int[] { 1, 2, 3, 4 }, 4, true)] + [InlineData(new int[] { 1, 2, 3, 4 }, 5, false)] + public static void Contains_Int32(int[] array, int value, bool expectedResult) + { + // Test with short Span + Span span = new Span(array); + bool result = span.Contains(value); + Assert.Equal(result, expectedResult); + + // Test with long Span + for (int i = 0; i < 10; i++) + array = array.Concat(array).ToArray(); + span = new Span(array); + result = span.Contains(value); + Assert.Equal(result, expectedResult); + } + + [Theory] + [InlineData(new long[] { 1, 2, 3, 4 }, 4, true)] + [InlineData(new long[] { 1, 2, 3, 4 }, 5, false)] + public static void Contains_Int64(long[] array, long value, bool expectedResult) + { + // Test with short Span + Span span = new Span(array); + bool result = span.Contains(value); + Assert.Equal(result, expectedResult); + + // Test with long Span + for (int i = 0; i < 10; i++) + array = array.Concat(array).ToArray(); + span = new Span(array); + result = span.Contains(value); + Assert.Equal(result, expectedResult); + } + + [Theory] + [InlineData(new byte[] { 1, 2, 3, 4 }, 4, true)] + [InlineData(new byte[] { 1, 2, 3, 4 }, 5, false)] + public static void Contains_Byte(byte[] array, byte value, bool expectedResult) + { + // Test with short Span + Span span = new Span(array); + bool result = span.Contains(value); + Assert.Equal(result, expectedResult); + + // Test with long Span + for (int i = 0; i < 10; i++) + array = array.Concat(array).ToArray(); + span = new Span(array); + result = span.Contains(value); + Assert.Equal(result, expectedResult); + } + + [Theory] + [InlineData(new char[] { 'a', 'b', 'c', 'd' }, 'd', true)] + [InlineData(new char[] { 'a', 'b', 'c', 'd' }, 'e', false)] + public static void Contains_Char(char[] array, char value, bool expectedResult) + { + // Test with short Span + Span span = new Span(array); + bool result = span.Contains(value); + Assert.Equal(result, expectedResult); + + // Test with long Span + for (int i = 0; i < 10; i++) + array = array.Concat(array).ToArray(); + span = new Span(array); + result = span.Contains(value); + Assert.Equal(result, expectedResult); + + } + + [Theory] + [InlineData(new float[] { 1, 2, 3, 4 }, 4, true)] + [InlineData(new float[] { 1, 2, 3, 4 }, 5, false)] + public static void Contains_Float(float[] array, float value, bool expectedResult) + { + // Test with short Span + Span span = new Span(array); + bool result = span.Contains(value); + Assert.Equal(result, expectedResult); + + // Test with long Span + for (int i = 0; i < 10; i++) + array = array.Concat(array).ToArray(); + span = new Span(array); + result = span.Contains(value); + Assert.Equal(result, expectedResult); + } + + [Theory] + [InlineData(new double[] { 1, 2, 3, 4 }, 4, true)] + [InlineData(new double[] { 1, 2, 3, 4 }, 5, false)] + public static void Contains_Double(double[] array, double value, bool expectedResult) + { + // Test with short Span + Span span = new Span(array); + bool result = span.Contains(value); + Assert.Equal(result, expectedResult); + + // Test with long Span + for (int i = 0; i < 10; i++) + array = array.Concat(array).ToArray(); + span = new Span(array); + result = span.Contains(value); + Assert.Equal(result, expectedResult); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Array.cs b/src/libraries/System.Private.CoreLib/src/System/Array.cs index d305987cc03be5..dfaf5c53e9eabe 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Array.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Array.cs @@ -1232,18 +1232,28 @@ ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)) } else if (Unsafe.SizeOf() == sizeof(int)) { - int result = SpanHelpers.IndexOf( - ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)), startIndex), - Unsafe.As(ref value), - count); + int result = typeof(T).IsValueType + ? SpanHelpers.IndexOfValueType( + ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)), startIndex), + Unsafe.As(ref value), + count) + : SpanHelpers.IndexOf( + ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)), startIndex), + Unsafe.As(ref value), + count); return (result >= 0 ? startIndex : 0) + result; } else if (Unsafe.SizeOf() == sizeof(long)) { - int result = SpanHelpers.IndexOf( - ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)), startIndex), - Unsafe.As(ref value), - count); + int result = typeof(T).IsValueType + ? SpanHelpers.IndexOfValueType( + ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)), startIndex), + Unsafe.As(ref value), + count) + : SpanHelpers.IndexOf( + ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As(array)), startIndex), + Unsafe.As(ref value), + count); return (result >= 0 ? startIndex : 0) + result; } } diff --git a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs index 562405483103d3..d3c0039fbef62e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs @@ -279,6 +279,18 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(span)), Unsafe.As(ref value), span.Length); + + if (Unsafe.SizeOf() == sizeof(int)) + return -1 != SpanHelpers.IndexOfValueType( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + Unsafe.As(ref value), + span.Length); + + if (Unsafe.SizeOf() == sizeof(long)) + return -1 != SpanHelpers.IndexOfValueType( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + Unsafe.As(ref value), + span.Length); } return SpanHelpers.Contains(ref MemoryMarshal.GetReference(span), value, span.Length); @@ -306,6 +318,18 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(span)), Unsafe.As(ref value), span.Length); + + if (Unsafe.SizeOf() == sizeof(int)) + return -1 != SpanHelpers.IndexOfValueType( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + Unsafe.As(ref value), + span.Length); + + if (Unsafe.SizeOf() == sizeof(long)) + return -1 != SpanHelpers.IndexOfValueType( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + Unsafe.As(ref value), + span.Length); } return SpanHelpers.Contains(ref MemoryMarshal.GetReference(span), value, span.Length); @@ -332,6 +356,18 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(span)), Unsafe.As(ref value), span.Length); + + if (Unsafe.SizeOf() == sizeof(int)) + return SpanHelpers.IndexOfValueType( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + Unsafe.As(ref value), + span.Length); + + if (Unsafe.SizeOf() == sizeof(long)) + return SpanHelpers.IndexOfValueType( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + Unsafe.As(ref value), + span.Length); } return SpanHelpers.IndexOf(ref MemoryMarshal.GetReference(span), value, span.Length); diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs index f5de74e1166e4a..a5c8f9ae79a602 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Numerics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Runtime.Intrinsics; using Internal.Runtime.CompilerServices; @@ -291,6 +292,115 @@ public static unsafe bool Contains(ref T searchSpace, T value, int length) wh return true; } + internal static unsafe int IndexOfValueType(ref T searchSpace, T value, int length) where T : struct, IEquatable + { + Debug.Assert(length >= 0); + + nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations + if (Vector.IsHardwareAccelerated && Vector.IsTypeSupported && (Vector.Count * 2) <= length) + { + Vector valueVector = new Vector(value); + Vector compareVector = default; + Vector matchVector = default; + if ((uint)length % (uint)Vector.Count != 0) + { + // Number of elements is not a multiple of Vector.Count, so do one + // check and shift only enough for the remaining set to be a multiple + // of Vector.Count. + compareVector = Unsafe.As>(ref Unsafe.Add(ref searchSpace, index)); + matchVector = Vector.Equals(valueVector, compareVector); + if (matchVector != Vector.Zero) + { + goto VectorMatch; + } + index += length % Vector.Count; + length -= length % Vector.Count; + } + while (length > 0) + { + compareVector = Unsafe.As>(ref Unsafe.Add(ref searchSpace, index)); + matchVector = Vector.Equals(valueVector, compareVector); + if (matchVector != Vector.Zero) + { + goto VectorMatch; + } + index += Vector.Count; + length -= Vector.Count; + } + goto NotFound; + VectorMatch: + for (int i = 0; i < Vector.Count; i++) + if (compareVector[i].Equals(value)) + return (int)(index + i); + } + + while (length >= 8) + { + if (value.Equals(Unsafe.Add(ref searchSpace, index))) + goto Found; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 1))) + goto Found1; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 2))) + goto Found2; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 3))) + goto Found3; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 4))) + goto Found4; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 5))) + goto Found5; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 6))) + goto Found6; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 7))) + goto Found7; + + length -= 8; + index += 8; + } + + while (length >= 4) + { + if (value.Equals(Unsafe.Add(ref searchSpace, index))) + goto Found; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 1))) + goto Found1; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 2))) + goto Found2; + if (value.Equals(Unsafe.Add(ref searchSpace, index + 3))) + goto Found3; + + length -= 4; + index += 4; + } + + while (length > 0) + { + if (value.Equals(Unsafe.Add(ref searchSpace, index))) + goto Found; + + index += 1; + length--; + } + NotFound: + return -1; + + Found: // Workaround for https://github.com/dotnet/runtime/issues/8795 + return (int)index; + Found1: + return (int)(index + 1); + Found2: + return (int)(index + 2); + Found3: + return (int)(index + 3); + Found4: + return (int)(index + 4); + Found5: + return (int)(index + 5); + Found6: + return (int)(index + 6); + Found7: + return (int)(index + 7); + } + public static unsafe int IndexOf(ref T searchSpace, T value, int length) where T : IEquatable { Debug.Assert(length >= 0);