Files
QWERTYkez.Mensura/QWERTYkez.Mensura/Extensions/CollectionsSqrtExtensions.cs
2026-06-10 11:58:39 +07:00

317 lines
13 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using System.Runtime.Intrinsics.X86;
namespace QWERTYkez.Mensura.Extensions;
internal static partial class CollectionsSqrtExtensions
{
// === SqrtCore === SIMD
/// <summary>
/// Writes the square root of each of the first <paramref name="len"/> elements of
/// <paramref name="units"/> into the matching slot of <paramref name="destination"/>,
/// using AVX or <see cref="Vector{T}"/> hardware acceleration when available.
/// </summary>
/// <remarks>
/// Both buffers are reinterpreted as <see cref="double"/> through <c>MemoryMarshal.Cast</c>.
/// NOTE(review): this presumes <typeparamref name="T"/> and <typeparamref name="R"/> have the
/// layout of a single double; nothing in this method enforces it — confirm against the unit types.
/// Source and destination may refer to the same memory: every element is read before it is written.
/// </remarks>
/// <param name="units">Source values.</param>
/// <param name="len">Number of leading elements to process.</param>
/// <param name="destination">Receives one result per processed element.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="len"/> exceeds the length of either reinterpreted buffer.
/// </exception>
internal static unsafe void SqrtCore<T, R>(this ReadOnlySpan<T> units, int len, Span<R> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    ReadOnlySpan<double> srcDouble = MemoryMarshal.Cast<T, double>(units);
    Span<double> dstDouble = MemoryMarshal.Cast<R, double>(destination);

    // Everything below uses raw pointer / ref arithmetic without bounds checks,
    // so a length that overruns either buffer must be rejected up front.
    if (len > srcDouble.Length || len > dstDouble.Length)
    {
        throw new ArgumentOutOfRangeException(nameof(len));
    }

    int i = 0;
    ref double srcRef = ref MemoryMarshal.GetReference(srcDouble);
    ref double dstRef = ref MemoryMarshal.GetReference(dstDouble);

    // 1. AVX path (x64 Intel/AMD): four doubles per iteration.
    if (Avx.IsSupported && len >= 4)
    {
        int simdEnd = len & ~3;

        // Pin both buffers for the duration of the loop. A pointer obtained from an
        // unpinned managed reference may be invalidated if the GC relocates the buffer.
        fixed (double* pSrc = srcDouble)
        fixed (double* pDst = dstDouble)
        {
            for (; i < simdEnd; i += 4)
            {
                // LoadVector256 / Store are the unaligned forms: no 32-byte alignment is required.
                var v = Avx.LoadVector256(pSrc + i);
                Avx.Store(pDst + i, Avx.Sqrt(v));
            }
        }
    }
    // 2. Portable Vector<T> path (e.g. ARM64, or x86 CPUs without AVX).
    else if (Vector.IsHardwareAccelerated && len >= Vector<double>.Count)
    {
        int vCount = Vector<double>.Count;
        int simdEnd = len & ~(vCount - 1);
        for (; i < simdEnd; i += vCount)
        {
            var v = new Vector<double>(srcDouble.Slice(i, vCount));
            Vector.SquareRoot(v).CopyTo(dstDouble.Slice(i, vCount));
        }
    }

    // 3. Scalar tail (or the whole range when no SIMD path was taken).
    for (; i < len; i++)
    {
        Unsafe.Add(ref dstRef, i) = Math.Sqrt(Unsafe.Add(ref srcRef, i));
    }
}
/// <summary>
/// Writes the square root of each of the first <paramref name="len"/> nullable elements of
/// <paramref name="units"/> into the matching slot of <paramref name="destination"/>.
/// A <c>null</c> source element produces a <c>null</c> result.
/// </summary>
/// <remarks>
/// Each present value is converted with <c>ToDouble()</c>, passed through
/// <see cref="Math.Sqrt(double)"/>, and converted back with <c>ToUnit&lt;R&gt;()</c>.
/// The main loop is manually unrolled four elements at a time.
/// </remarks>
/// <param name="units">Source values.</param>
/// <param name="len">Number of leading elements to process.</param>
/// <param name="destination">Receives one result per processed element.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="len"/> exceeds the length of either buffer.
/// </exception>
internal static void SqrtCore<T, R>(this ReadOnlySpan<T?> units, int len, Span<R?> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    // The loops below go through Unsafe.Add without bounds checks,
    // so a length that overruns either buffer must be rejected up front.
    if (len > units.Length || len > destination.Length)
    {
        throw new ArgumentOutOfRangeException(nameof(len));
    }

    ref var srcRef = ref MemoryMarshal.GetReference(units);
    ref var dstRef = ref MemoryMarshal.GetReference(destination);

    int i = 0;
    int unrollEnd = len & ~3; // largest multiple of 4 not exceeding len

    // 1. Unrolled main loop: four elements per iteration.
    for (; i < unrollEnd; i += 4)
    {
        T? u0 = Unsafe.Add(ref srcRef, i);
        T? u1 = Unsafe.Add(ref srcRef, i + 1);
        T? u2 = Unsafe.Add(ref srcRef, i + 2);
        T? u3 = Unsafe.Add(ref srcRef, i + 3);

        Unsafe.Add(ref dstRef, i) = Root(u0);
        Unsafe.Add(ref dstRef, i + 1) = Root(u1);
        Unsafe.Add(ref dstRef, i + 2) = Root(u2);
        Unsafe.Add(ref dstRef, i + 3) = Root(u3);
    }

    // 2. Tail: the remaining 0 to 3 elements.
    for (; i < len; i++)
    {
        Unsafe.Add(ref dstRef, i) = Root(Unsafe.Add(ref srcRef, i));
    }

    // Converts one nullable unit to its square root, propagating null.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    static R? Root(T? unit) =>
        unit.HasValue ? Math.Sqrt(unit.Value.ToDouble()).ToUnit<R>() : null;
}
// === ReadOnlySpan ===
/// <summary>
/// Computes the element-wise square root of <paramref name="units"/> into
/// <paramref name="destination"/> after validating the buffer sizes.
/// </summary>
/// <param name="units">Source values; an empty span makes the call a no-op.</param>
/// <param name="destination">Target buffer; must be at least as long as <paramref name="units"/>.</param>
/// <exception cref="ArgumentException">
/// <paramref name="destination"/> is shorter than <paramref name="units"/>.
/// </exception>
internal static void Sqrt<T, R>(this ReadOnlySpan<T> units, Span<R> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    int count = units.Length;
    if (count == 0)
    {
        return;
    }

    if (destination.Length < count)
    {
        throw new ArgumentException("Целевой буфер destination меньше исходного source.");
    }

    units.SqrtCore(count, destination);
}
/// <summary>
/// Computes the element-wise square root of the nullable values in <paramref name="units"/>
/// into <paramref name="destination"/> after validating the buffer sizes.
/// </summary>
/// <param name="units">Source values; an empty span makes the call a no-op.</param>
/// <param name="destination">Target buffer; must be at least as long as <paramref name="units"/>.</param>
/// <exception cref="ArgumentException">
/// <paramref name="destination"/> is shorter than <paramref name="units"/>.
/// </exception>
internal static void Sqrt<T, R>(this ReadOnlySpan<T?> units, Span<R?> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    int count = units.Length;
    if (count == 0)
    {
        return;
    }

    if (destination.Length < count)
    {
        throw new ArgumentException("Целевой буфер destination меньше исходного source.");
    }

    units.SqrtCore(count, destination);
}
// === Array ===
/// <summary>
/// Returns a new array holding the square root of every element of <paramref name="units"/>.
/// </summary>
/// <param name="units">Source array; may be <c>null</c>.</param>
/// <returns>
/// <c>null</c> when <paramref name="units"/> is <c>null</c>, an empty array when it is empty,
/// otherwise a freshly allocated array of the same length.
/// </returns>
internal static R[] Sqrt<T, R>(this T[] units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return null!;
    }

    int count = units.Length;
    if (count == 0)
    {
        return [];
    }

    var roots = new R[count];
    Sqrt(units, roots);
    return roots;
}
/// <summary>
/// Returns a new array holding the square root of every nullable element of
/// <paramref name="units"/>.
/// </summary>
/// <param name="units">Source array; may be <c>null</c>.</param>
/// <returns>
/// <c>null</c> when <paramref name="units"/> is <c>null</c>, an empty array when it is empty,
/// otherwise a freshly allocated array of the same length.
/// </returns>
internal static R?[] Sqrt<T, R>(this T?[] units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return null!;
    }

    int count = units.Length;
    if (count == 0)
    {
        return [];
    }

    var roots = new R?[count];
    Sqrt(units, roots);
    return roots;
}
// === List<T> ===
/// <summary>
/// Returns a new list holding the square root of every element of <paramref name="units"/>.
/// </summary>
/// <remarks>
/// The list's backing storage is read through <c>CollectionsMarshal.AsSpan</c>; results are
/// computed into a new array which is then handed to <c>WrapAsList</c>.
/// </remarks>
/// <param name="units">Source list; may be <c>null</c>.</param>
/// <returns>
/// <c>null</c> when <paramref name="units"/> is <c>null</c>, an empty list when it is empty,
/// otherwise a list with one result per source element.
/// </returns>
internal static List<R> Sqrt<T, R>(this List<T> units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return null!;
    }

    int count = units.Count;
    if (count == 0)
    {
        return [];
    }

    var roots = new R[count];
    Sqrt(CollectionsMarshal.AsSpan(units), roots);
    return roots.WrapAsList<R, R>();
}
/// <summary>
/// Returns a new list holding the square root of every nullable element of
/// <paramref name="units"/>.
/// </summary>
/// <remarks>
/// The list's backing storage is read through <c>CollectionsMarshal.AsSpan</c>; results are
/// computed into a new array which is then handed to <c>WrapAsList</c>.
/// </remarks>
/// <param name="units">Source list; may be <c>null</c>.</param>
/// <returns>
/// <c>null</c> when <paramref name="units"/> is <c>null</c>, an empty list when it is empty,
/// otherwise a list with one result per source element.
/// </returns>
internal static List<R?> Sqrt<T, R>(this List<T?> units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return null!;
    }

    int count = units.Count;
    if (count == 0)
    {
        return [];
    }

    var roots = new R?[count];
    Sqrt(CollectionsMarshal.AsSpan(units), roots);
    return roots.WrapAsList<R, R>();
}
// === ICollection<T> ===
/// <summary>
/// Writes the square root of every element of <paramref name="units"/> into
/// <paramref name="destination"/>, in enumeration order.
/// </summary>
/// <remarks>
/// Arrays and <see cref="List{T}"/> instances are forwarded to the span-based overloads;
/// any other collection is processed element by element.
/// </remarks>
/// <param name="units">Source collection; <c>null</c> or empty makes the call a no-op.</param>
/// <param name="destination">Target buffer; must hold at least <c>units.Count</c> elements.</param>
/// <exception cref="ArgumentException">
/// <paramref name="destination"/> is shorter than the collection.
/// </exception>
internal static void Sqrt<T, R>(this ICollection<T> units, Span<R> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return;
    }

    int total = units.Count;
    if (total == 0)
    {
        return;
    }

    if (destination.Length < total)
    {
        throw new ArgumentException("Destination too short");
    }

    if (units is T[] array)
    {
        array.Sqrt(destination);
        return;
    }

    if (units is List<T> list)
    {
        CollectionsMarshal.AsSpan(list).Sqrt(destination);
        return;
    }

    // Generic fallback: one element at a time.
    int index = 0;
    foreach (T item in units)
    {
        destination[index] = Math.Sqrt(item.ToDouble()).ToUnit<R>();
        index++;
    }
}
/// <summary>
/// Writes the square root of every nullable element of <paramref name="units"/> into
/// <paramref name="destination"/>, in enumeration order; <c>null</c> elements stay <c>null</c>.
/// </summary>
/// <remarks>
/// Arrays and <see cref="List{T}"/> instances are forwarded to the span-based overloads;
/// any other collection is processed element by element.
/// </remarks>
/// <param name="units">Source collection; <c>null</c> or empty makes the call a no-op.</param>
/// <param name="destination">Target buffer; must hold at least <c>units.Count</c> elements.</param>
/// <exception cref="ArgumentException">
/// <paramref name="destination"/> is shorter than the collection.
/// </exception>
internal static void Sqrt<T, R>(this ICollection<T?> units, Span<R?> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return;
    }

    int total = units.Count;
    if (total == 0)
    {
        return;
    }

    if (destination.Length < total)
    {
        throw new ArgumentException("Destination too short");
    }

    if (units is T?[] array)
    {
        array.Sqrt(destination);
        return;
    }

    if (units is List<T?> list)
    {
        CollectionsMarshal.AsSpan(list).Sqrt(destination);
        return;
    }

    // Generic fallback: one element at a time, propagating null.
    int index = 0;
    foreach (T? item in units)
    {
        if (item.HasValue)
        {
            destination[index] = Math.Sqrt(item.Value.ToDouble()).ToUnit<R>();
        }
        else
        {
            destination[index] = null;
        }

        index++;
    }
}
// === IReadOnlyCollection<T> ===
/// <summary>
/// Writes the square root of every element of <paramref name="units"/> into
/// <paramref name="destination"/>, in enumeration order.
/// </summary>
/// <remarks>
/// Arrays and <see cref="List{T}"/> instances are forwarded to the span-based overloads;
/// any other collection is processed element by element.
/// </remarks>
/// <param name="units">Source collection; <c>null</c> or empty makes the call a no-op.</param>
/// <param name="destination">Target buffer; must hold at least <c>units.Count</c> elements.</param>
/// <exception cref="ArgumentException">
/// <paramref name="destination"/> is shorter than the collection.
/// </exception>
internal static void Sqrt<T, R>(this IReadOnlyCollection<T> units, Span<R> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return;
    }

    int total = units.Count;
    if (total == 0)
    {
        return;
    }

    if (destination.Length < total)
    {
        throw new ArgumentException("Destination too short");
    }

    if (units is T[] array)
    {
        array.Sqrt(destination);
        return;
    }

    if (units is List<T> list)
    {
        CollectionsMarshal.AsSpan(list).Sqrt(destination);
        return;
    }

    // Generic fallback: one element at a time.
    int index = 0;
    foreach (T item in units)
    {
        destination[index] = Math.Sqrt(item.ToDouble()).ToUnit<R>();
        index++;
    }
}
/// <summary>
/// Writes the square root of every nullable element of <paramref name="units"/> into
/// <paramref name="destination"/>, in enumeration order; <c>null</c> elements stay <c>null</c>.
/// </summary>
/// <remarks>
/// Arrays and <see cref="List{T}"/> instances are forwarded to the span-based overloads;
/// any other collection is processed element by element.
/// </remarks>
/// <param name="units">Source collection; <c>null</c> or empty makes the call a no-op.</param>
/// <param name="destination">Target buffer; must hold at least <c>units.Count</c> elements.</param>
/// <exception cref="ArgumentException">
/// <paramref name="destination"/> is shorter than the collection.
/// </exception>
internal static void Sqrt<T, R>(this IReadOnlyCollection<T?> units, Span<R?> destination)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return;
    }

    int total = units.Count;
    if (total == 0)
    {
        return;
    }

    if (destination.Length < total)
    {
        throw new ArgumentException("Destination too short");
    }

    if (units is T?[] array)
    {
        array.Sqrt(destination);
        return;
    }

    if (units is List<T?> list)
    {
        CollectionsMarshal.AsSpan(list).Sqrt(destination);
        return;
    }

    // Generic fallback: one element at a time, propagating null.
    int index = 0;
    foreach (T? item in units)
    {
        if (item.HasValue)
        {
            destination[index] = Math.Sqrt(item.Value.ToDouble()).ToUnit<R>();
        }
        else
        {
            destination[index] = null;
        }

        index++;
    }
}
// === IEnumerable<T> + yield ===
/// <summary>
/// Lazily yields the square root of each element of <paramref name="units"/>.
/// </summary>
/// <remarks>
/// This is an iterator method: nothing is evaluated until the result is enumerated.
/// </remarks>
/// <param name="units">Source sequence.</param>
/// <returns>A deferred sequence with one result per source element.</returns>
internal static IEnumerable<R> SqrtIterator<T, R>(this IEnumerable<T> units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    foreach (T value in units)
    {
        double root = Math.Sqrt(value.ToDouble());
        yield return root.ToUnit<R>();
    }
}
/// <summary>
/// Lazily yields the square root of each nullable element of <paramref name="units"/>;
/// a <c>null</c> element yields <c>null</c>.
/// </summary>
/// <remarks>
/// This is an iterator method: nothing is evaluated until the result is enumerated.
/// </remarks>
/// <param name="units">Source sequence.</param>
/// <returns>A deferred sequence with one result per source element.</returns>
internal static IEnumerable<R?> SqrtNullableIterator<T, R>(this IEnumerable<T?> units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    foreach (T? candidate in units)
    {
        if (candidate.HasValue)
        {
            yield return Math.Sqrt(candidate.Value.ToDouble()).ToUnit<R>();
        }
        else
        {
            yield return null;
        }
    }
}
// === IEnumerable<T> ===
/// <summary>
/// Returns the square root of each element of <paramref name="units"/>, choosing the
/// implementation by the runtime type of the sequence.
/// </summary>
/// <remarks>
/// Arrays and lists go to their dedicated overloads. Other <see cref="ICollection{T}"/> or
/// <see cref="IReadOnlyCollection{T}"/> instances are copied with <c>ToArray()</c>, processed
/// in place on that copy, and returned through <c>ReCast</c>. Everything else is wrapped in the
/// lazy <c>SqrtIterator</c>.
/// </remarks>
/// <param name="units">Source sequence; may be <c>null</c>.</param>
/// <returns><c>null</c> when <paramref name="units"/> is <c>null</c>; otherwise the results.</returns>
internal static IEnumerable<R> Sqrt<T, R>(this IEnumerable<T> units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return null!;
    }

    if (units is T[] array)
    {
        return array.Sqrt<T, R>();
    }

    if (units is List<T> list)
    {
        return list.Sqrt<T, R>();
    }

    if (units is ICollection<T> collection)
    {
        // Work on a private copy so the caller's collection is left untouched.
        var copy = collection.ToArray();
        copy.Sqrt(copy);
        return copy.ReCast<T, R>();
    }

    if (units is IReadOnlyCollection<T> readOnly)
    {
        // Work on a private copy so the caller's collection is left untouched.
        var copy = readOnly.ToArray();
        copy.Sqrt(copy);
        return copy.ReCast<T, R>();
    }

    return SqrtIterator<T, R>(units);
}
/// <summary>
/// Returns the square root of each nullable element of <paramref name="units"/>, choosing the
/// implementation by the runtime type of the sequence.
/// </summary>
/// <remarks>
/// Arrays and lists go to their dedicated overloads. Other <see cref="ICollection{T}"/> or
/// <see cref="IReadOnlyCollection{T}"/> instances are copied with <c>ToArray()</c>, processed
/// in place on that copy, and returned through <c>ReCast</c>. Everything else is wrapped in the
/// lazy <c>SqrtNullableIterator</c>.
/// </remarks>
/// <param name="units">Source sequence; may be <c>null</c>.</param>
/// <returns><c>null</c> when <paramref name="units"/> is <c>null</c>; otherwise the results.</returns>
internal static IEnumerable<R?> Sqrt<T, R>(this IEnumerable<T?> units)
    where T : struct, IMensuraUnit, IEquatable<T>
    where R : struct, IMensuraUnit, IEquatable<R>
{
    if (units is null)
    {
        return null!;
    }

    if (units is T?[] array)
    {
        return array.Sqrt<T, R>();
    }

    if (units is List<T?> list)
    {
        return list.Sqrt<T, R>();
    }

    if (units is ICollection<T?> collection)
    {
        // Work on a private copy so the caller's collection is left untouched.
        var copy = collection.ToArray();
        copy.Sqrt(copy);
        return copy.ReCast<T, R>();
    }

    if (units is IReadOnlyCollection<T?> readOnly)
    {
        // Work on a private copy so the caller's collection is left untouched.
        var copy = readOnly.ToArray();
        copy.Sqrt(copy);
        return copy.ReCast<T, R>();
    }

    return SqrtNullableIterator<T, R>(units);
}
}