Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Linq
{
Expand All @@ -17,13 +20,7 @@ public override IEnumerable<TResult> Select<TResult>(Func<int, TResult> selector
public int[] ToArray()
{
int[] array = new int[_end - _start];
int cur = _start;
for (int i = 0; i < array.Length; ++i)
{
array[i] = cur;
++cur;
}

InitializeSpan(array);
return array;
}

Expand Down Expand Up @@ -84,6 +81,61 @@ public int TryGetLast(out bool found)
found = true;
return _end - 1;
}

// Destination *must* be non-empty and match the range length
private void InitializeSpan(Span<int> destination)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be inlined. I'd expect this to regress for very short ranges due to the extra method call that wasn't there before.

{
if (destination.Length < Vector<int>.Count * 2)
{
int cur = _start;
for (int i = 0; i < destination.Length; i++)
{
destination[i] = cur;
cur++;
}
}
else
{
InitializeSpanCore(destination);
}
}

private void InitializeSpanCore(Span<int> destination)
{
int width = Vector<int>.Count;
int stride = Vector<int>.Count * 2;
int remainder = destination.Length % stride;

// Up to 16 elements which corresponds to AVX512
Vector<int> initMask = Unsafe.ReadUnaligned<Vector<int>>(
ref Unsafe.As<int, byte>(ref MemoryMarshal.GetReference(
(ReadOnlySpan<int>)new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 })));

Vector<int> mask = new Vector<int>(stride);
Vector<int> value = new Vector<int>(_start) + initMask;
Vector<int> value2 = value + new Vector<int>(width);

ref int pos = ref MemoryMarshal.GetReference(destination);
ref int limit = ref Unsafe.Add(ref pos, destination.Length - remainder);
while (!Unsafe.AreSame(ref pos, ref limit))
{
Unsafe.WriteUnaligned(ref Unsafe.As<int, byte>(ref pos), value);
Unsafe.WriteUnaligned(ref Unsafe.As<int, byte>(ref Unsafe.Add(ref pos, width)), value2);

value += mask;
value2 += mask;
pos = ref Unsafe.Add(ref pos, stride);
}

int cur = _start + (destination.Length - remainder);
int end = _end;
while (cur < end)
{
pos = cur;
pos = ref Unsafe.Add(ref pos, 1);
cur++;
}
}
}
}
}