Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ public class DoubleGenericTensorPrimitives : GenericFloatingPointNumberTensorPri
public class SingleGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<float> { }
public class HalfGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<Half>
{
protected override void AssertEqualTolerance(Half expected, Half actual) => AssertEqualTolerance(expected, actual, Half.CreateTruncating(0.001));
protected override void AssertEqualTolerance(Half expected, Half actual, Half? tolerance = null) =>
base.AssertEqualTolerance(expected, actual, tolerance ?? Half.CreateTruncating(0.001));
}
public class NFloatGenericTensorPrimitives : GenericFloatingPointNumberTensorPrimitivesTests<NFloat> { }

Expand Down Expand Up @@ -158,9 +159,9 @@ public static IEnumerable<object[]> SpanDestinationFunctionsToTest()
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.Log10P1), new Func<T, T>(T.Log10P1) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.RadiansToDegrees), new Func<T, T>(T.RadiansToDegrees) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.Reciprocal), new Func<T, T>(f => T.One / f) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.ReciprocalEstimate), new Func<T, T>(T.ReciprocalEstimate) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.ReciprocalEstimate), new Func<T, T>(T.ReciprocalEstimate), T.CreateTruncating(1.171875e-02) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.ReciprocalSqrt), new Func<T, T>(f => T.One / T.Sqrt(f)) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.ReciprocalSqrtEstimate), new Func<T, T>(T.ReciprocalSqrtEstimate) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.ReciprocalSqrtEstimate), new Func<T, T>(T.ReciprocalSqrtEstimate), T.CreateTruncating(1.171875e-02) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.Round), new Func<T, T>(T.Round) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.Sin), new Func<T, T>(T.Sin) };
yield return new object[] { new SpanDestinationDelegate(TensorPrimitives.Sinh), new Func<T, T>(T.Sinh) };
Expand All @@ -174,8 +175,7 @@ public static IEnumerable<object[]> SpanDestinationFunctionsToTest()

[Theory]
[MemberData(nameof(SpanDestinationFunctionsToTest))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/97297")]
public void SpanDestinationFunctions_AllLengths(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod)
public void SpanDestinationFunctions_AllLengths(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod, T? tolerance = null)
{
Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
{
Expand All @@ -186,15 +186,14 @@ public void SpanDestinationFunctions_AllLengths(SpanDestinationDelegate tensorPr

for (int i = 0; i < tensorLength; i++)
{
AssertEqualTolerance(expectedMethod(x[i]), destination[i]);
AssertEqualTolerance(expectedMethod(x[i]), destination[i], tolerance);
}
});
}

[Theory]
[MemberData(nameof(SpanDestinationFunctionsToTest))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/97297")]
public void SpanDestinationFunctions_InPlace(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod)
public void SpanDestinationFunctions_InPlace(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod, T? tolerance = null)
{
Assert.All(Helpers.TensorLengthsIncluding0, tensorLength =>
{
Expand All @@ -205,15 +204,14 @@ public void SpanDestinationFunctions_InPlace(SpanDestinationDelegate tensorPrimi

for (int i = 0; i < tensorLength; i++)
{
AssertEqualTolerance(expectedMethod(xOrig[i]), x[i]);
AssertEqualTolerance(expectedMethod(xOrig[i]), x[i], tolerance);
}
});
}

[Theory]
[MemberData(nameof(SpanDestinationFunctionsToTest))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/97297")]
public void SpanDestinationFunctions_SpecialValues(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod)
public void SpanDestinationFunctions_SpecialValues(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod, T? tolerance = null)
{
Assert.All(Helpers.TensorLengths, tensorLength =>
{
Expand All @@ -225,16 +223,15 @@ public void SpanDestinationFunctions_SpecialValues(SpanDestinationDelegate tenso
tensorPrimitivesMethod(x.Span, destination.Span);
for (int i = 0; i < tensorLength; i++)
{
AssertEqualTolerance(expectedMethod(x[i]), destination[i]);
AssertEqualTolerance(expectedMethod(x[i]), destination[i], tolerance);
}
}, x);
});
}

[Theory]
[MemberData(nameof(SpanDestinationFunctionsToTest))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/97297")]
public void SpanDestinationFunctions_ValueRange(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod)
public void SpanDestinationFunctions_ValueRange(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> expectedMethod, T? tolerance = null)
{
Assert.All(VectorLengthAndIteratedRange(ConvertFromSingle(-100f), ConvertFromSingle(100f), ConvertFromSingle(3f)), arg =>
{
Expand All @@ -247,14 +244,15 @@ public void SpanDestinationFunctions_ValueRange(SpanDestinationDelegate tensorPr
T expected = expectedMethod(arg.Element);
foreach (T actual in dest)
{
AssertEqualTolerance(expected, actual);
AssertEqualTolerance(expected, actual, tolerance);
}
});
}

#pragma warning disable xUnit1026 // Theory methods should use all of their parameters
[Theory]
[MemberData(nameof(SpanDestinationFunctionsToTest))]
public void SpanDestinationFunctions_ThrowsForTooShortDestination(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> _)
public void SpanDestinationFunctions_ThrowsForTooShortDestination(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> _, T? __ = null)
{
Assert.All(Helpers.TensorLengths, tensorLength =>
{
Expand All @@ -267,12 +265,13 @@ public void SpanDestinationFunctions_ThrowsForTooShortDestination(SpanDestinatio

[Theory]
[MemberData(nameof(SpanDestinationFunctionsToTest))]
public void SpanDestinationFunctions_ThrowsForOverlapppingInputsWithOutputs(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> _)
public void SpanDestinationFunctions_ThrowsForOverlapppingInputsWithOutputs(SpanDestinationDelegate tensorPrimitivesMethod, Func<T, T> _, T? __ = null)
{
T[] array = new T[10];
AssertExtensions.Throws<ArgumentException>("destination", () => tensorPrimitivesMethod(array.AsSpan(1, 2), array.AsSpan(0, 2)));
AssertExtensions.Throws<ArgumentException>("destination", () => tensorPrimitivesMethod(array.AsSpan(1, 2), array.AsSpan(2, 2)));
}
#pragma warning restore xUnit1026
#endregion

#region Span,Span -> Destination
Expand Down Expand Up @@ -1639,10 +1638,10 @@ protected override T NextRandom()
}
}

protected override void AssertEqualTolerance(T expected, T actual) => AssertEqualTolerance(expected, actual, T.CreateTruncating(0.0001));

protected override void AssertEqualTolerance(T expected, T actual, T tolerance)
protected override void AssertEqualTolerance(T expected, T actual, T? tolerance = null)
{
tolerance ??= T.CreateTruncating(0.0001);

T diff = T.Abs(expected - actual);
if (diff > tolerance && diff > T.Max(T.Abs(expected), T.Abs(actual)) * tolerance)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ protected override float MinMagnitude(float x, float y)

protected override float NextRandom() => (float)((Random.NextDouble() * 2) - 1); // For testing purposes, get a mix of negative and positive values.

protected override void AssertEqualTolerance(float expected, float actual) => AssertEqualTolerance(expected, actual, 0.0001f);

protected override void AssertEqualTolerance(float expected, float actual, float tolerance)
protected override void AssertEqualTolerance(float expected, float actual, float? tolerance = null)
{
tolerance ??= 0.0001f;

double diff = Math.Abs((double)expected - (double)actual);
if (diff > tolerance && diff > Math.Max(Math.Abs(expected), Math.Abs(actual)) * tolerance)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ public abstract class TensorPrimitivesTests<T> where T : unmanaged, IEquatable<T

protected abstract T NextRandom();

protected abstract void AssertEqualTolerance(T expected, T actual);

protected abstract void AssertEqualTolerance(T expected, T actual, T tolerance);
protected abstract void AssertEqualTolerance(T expected, T actual, T? tolerance = null);

protected abstract IEnumerable<(int Length, T Element)> VectorLengthAndIteratedRange(T min, T max, T increment);

Expand Down