|
5 | 5 | using System.Runtime.CompilerServices; |
6 | 6 | using System.Runtime.Intrinsics; |
7 | 7 | using System.Runtime.Intrinsics.X86; |
| 8 | +using NumSharp.Utilities; |
8 | 9 |
|
9 | 10 | // ============================================================================= |
10 | 11 | // ILKernelGenerator.Where - IL-generated np.where(condition, x, y) kernels |
@@ -400,7 +401,7 @@ private static void EmitWhereV128BodyWithOffset<T>(ILGenerator il, LocalBuilder |
400 | 401 | private static void EmitWhereScalarElement<T>(ILGenerator il, LocalBuilder locI) where T : unmanaged |
401 | 402 | { |
402 | 403 | long elementSize = Unsafe.SizeOf<T>(); |
403 | | - var typeCode = GetNPTypeCode<T>(); |
| 404 | + var typeCode = InfoOf<T>.NPTypeCode; |
404 | 405 |
|
405 | 406 | // result[i] = cond[i] ? x[i] : y[i] |
406 | 407 | var lblFalse = il.DefineLabel(); |
@@ -449,51 +450,6 @@ private static void EmitWhereScalarElement<T>(ILGenerator il, LocalBuilder locI) |
449 | 450 | EmitStoreIndirect(il, typeCode); |
450 | 451 | } |
451 | 452 |
|
452 | | - private static NPTypeCode GetNPTypeCode<T>() where T : unmanaged |
453 | | - { |
454 | | - if (typeof(T) == typeof(bool)) return NPTypeCode.Boolean; |
455 | | - if (typeof(T) == typeof(byte)) return NPTypeCode.Byte; |
456 | | - if (typeof(T) == typeof(short)) return NPTypeCode.Int16; |
457 | | - if (typeof(T) == typeof(ushort)) return NPTypeCode.UInt16; |
458 | | - if (typeof(T) == typeof(int)) return NPTypeCode.Int32; |
459 | | - if (typeof(T) == typeof(uint)) return NPTypeCode.UInt32; |
460 | | - if (typeof(T) == typeof(long)) return NPTypeCode.Int64; |
461 | | - if (typeof(T) == typeof(ulong)) return NPTypeCode.UInt64; |
462 | | - if (typeof(T) == typeof(char)) return NPTypeCode.Char; |
463 | | - if (typeof(T) == typeof(float)) return NPTypeCode.Single; |
464 | | - if (typeof(T) == typeof(double)) return NPTypeCode.Double; |
465 | | - if (typeof(T) == typeof(decimal)) return NPTypeCode.Decimal; |
466 | | - return NPTypeCode.Empty; |
467 | | - } |
468 | | - |
469 | | - #endregion |
470 | | - |
471 | | - #region Mask Creation Methods |
472 | | - |
473 | | - private static MethodInfo GetMaskCreationMethod256(int elementSize) |
474 | | - { |
475 | | - return elementSize switch |
476 | | - { |
477 | | - 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
478 | | - 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
479 | | - 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
480 | | - 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV256_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
481 | | - _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") |
482 | | - }; |
483 | | - } |
484 | | - |
485 | | - private static MethodInfo GetMaskCreationMethod128(int elementSize) |
486 | | - { |
487 | | - return elementSize switch |
488 | | - { |
489 | | - 1 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_1Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
490 | | - 2 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_2Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
491 | | - 4 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_4Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
492 | | - 8 => typeof(ILKernelGenerator).GetMethod(nameof(CreateMaskV128_8Byte), BindingFlags.NonPublic | BindingFlags.Static)!, |
493 | | - _ => throw new NotSupportedException($"Element size {elementSize} not supported for SIMD where") |
494 | | - }; |
495 | | - } |
496 | | - |
497 | 453 | #endregion |
498 | 454 |
|
499 | 455 | #region Inline Mask IL Emission |
@@ -560,8 +516,6 @@ private static MethodInfo GetMaskCreationMethod128(int elementSize) |
560 | 516 | private static readonly MethodInfo _v128GreaterThanByte = Array.Find(typeof(Vector128).GetMethods(), |
561 | 517 | m => m.Name == "GreaterThan" && m.IsGenericMethodDefinition)!.MakeGenericMethod(typeof(byte)); |
562 | 518 |
|
563 | | - private static readonly FieldInfo _v256ZeroULong = typeof(Vector256<ulong>).GetProperty("Zero")!.GetMethod!.IsStatic |
564 | | - ? null! : null!; // Use GetMethod call instead |
565 | 519 | private static readonly MethodInfo _v256GetZeroULong = typeof(Vector256<ulong>).GetProperty("Zero")!.GetMethod!; |
566 | 520 | private static readonly MethodInfo _v256GetZeroUInt = typeof(Vector256<uint>).GetProperty("Zero")!.GetMethod!; |
567 | 521 | private static readonly MethodInfo _v256GetZeroUShort = typeof(Vector256<ushort>).GetProperty("Zero")!.GetMethod!; |
@@ -716,205 +670,6 @@ private static void EmitInlineMaskCreationV128(ILGenerator il, int elementSize) |
716 | 670 |
|
717 | 671 | #endregion |
718 | 672 |
|
719 | | - #region Static Mask Creation Methods (fallback) |
720 | | - |
721 | | - /// <summary> |
722 | | - /// Create V256 mask from 32 bools for 1-byte elements. |
723 | | - /// </summary> |
724 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
725 | | - private static unsafe Vector256<byte> CreateMaskV256_1Byte(byte* bools) |
726 | | - { |
727 | | - var vec = Vector256.Load(bools); |
728 | | - var zero = Vector256<byte>.Zero; |
729 | | - var isZero = Vector256.Equals(vec, zero); |
730 | | - return Vector256.OnesComplement(isZero); |
731 | | - } |
732 | | - |
733 | | - /// <summary> |
734 | | - /// Create V256 mask from 16 bools for 2-byte elements. |
735 | | - /// Uses AVX2 vpmovzxbw instruction for single-instruction expansion. |
736 | | - /// </summary> |
737 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
738 | | - private static unsafe Vector256<ushort> CreateMaskV256_2Byte(byte* bools) |
739 | | - { |
740 | | - if (Avx2.IsSupported) |
741 | | - { |
742 | | - // Load 16 bytes into Vector128, zero-extend each byte to 16-bit |
743 | | - // vpmovzxbw: byte -> word (16 bytes -> 16 words) |
744 | | - var bytes128 = Vector128.Load(bools); |
745 | | - var expanded = Avx2.ConvertToVector256Int16(bytes128).AsUInt16(); |
746 | | - // Compare with zero: non-zero becomes 0xFFFF, zero stays 0 |
747 | | - return Vector256.GreaterThan(expanded, Vector256<ushort>.Zero); |
748 | | - } |
749 | | - |
750 | | - // Scalar fallback for non-AVX2 systems |
751 | | - return Vector256.Create( |
752 | | - bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, |
753 | | - bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, |
754 | | - bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, |
755 | | - bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, |
756 | | - bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, |
757 | | - bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, |
758 | | - bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, |
759 | | - bools[7] != 0 ? (ushort)0xFFFF : (ushort)0, |
760 | | - bools[8] != 0 ? (ushort)0xFFFF : (ushort)0, |
761 | | - bools[9] != 0 ? (ushort)0xFFFF : (ushort)0, |
762 | | - bools[10] != 0 ? (ushort)0xFFFF : (ushort)0, |
763 | | - bools[11] != 0 ? (ushort)0xFFFF : (ushort)0, |
764 | | - bools[12] != 0 ? (ushort)0xFFFF : (ushort)0, |
765 | | - bools[13] != 0 ? (ushort)0xFFFF : (ushort)0, |
766 | | - bools[14] != 0 ? (ushort)0xFFFF : (ushort)0, |
767 | | - bools[15] != 0 ? (ushort)0xFFFF : (ushort)0 |
768 | | - ); |
769 | | - } |
770 | | - |
771 | | - /// <summary> |
772 | | - /// Create V256 mask from 8 bools for 4-byte elements. |
773 | | - /// Uses AVX2 vpmovzxbd instruction for single-instruction expansion. |
774 | | - /// </summary> |
775 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
776 | | - private static unsafe Vector256<uint> CreateMaskV256_4Byte(byte* bools) |
777 | | - { |
778 | | - if (Avx2.IsSupported) |
779 | | - { |
780 | | - // Load 8 bytes into low bytes of Vector128, zero-extend each byte to 32-bit |
781 | | - // vpmovzxbd: byte -> dword (8 bytes -> 8 dwords) |
782 | | - var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); |
783 | | - var expanded = Avx2.ConvertToVector256Int32(bytes128).AsUInt32(); |
784 | | - // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 |
785 | | - return Vector256.GreaterThan(expanded, Vector256<uint>.Zero); |
786 | | - } |
787 | | - |
788 | | - // Scalar fallback for non-AVX2 systems |
789 | | - return Vector256.Create( |
790 | | - bools[0] != 0 ? 0xFFFFFFFFu : 0u, |
791 | | - bools[1] != 0 ? 0xFFFFFFFFu : 0u, |
792 | | - bools[2] != 0 ? 0xFFFFFFFFu : 0u, |
793 | | - bools[3] != 0 ? 0xFFFFFFFFu : 0u, |
794 | | - bools[4] != 0 ? 0xFFFFFFFFu : 0u, |
795 | | - bools[5] != 0 ? 0xFFFFFFFFu : 0u, |
796 | | - bools[6] != 0 ? 0xFFFFFFFFu : 0u, |
797 | | - bools[7] != 0 ? 0xFFFFFFFFu : 0u |
798 | | - ); |
799 | | - } |
800 | | - |
801 | | - /// <summary> |
802 | | - /// Create V256 mask from 4 bools for 8-byte elements. |
803 | | - /// Uses AVX2 vpmovzxbq instruction for single-instruction expansion. |
804 | | - /// </summary> |
805 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
806 | | - private static unsafe Vector256<ulong> CreateMaskV256_8Byte(byte* bools) |
807 | | - { |
808 | | - if (Avx2.IsSupported) |
809 | | - { |
810 | | - // Load 4 bytes into low bytes of Vector128, zero-extend each byte to 64-bit |
811 | | - // vpmovzxbq: byte -> qword (4 bytes -> 4 qwords) |
812 | | - var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); |
813 | | - var expanded = Avx2.ConvertToVector256Int64(bytes128).AsUInt64(); |
814 | | - // Compare with zero: non-zero becomes 0xFFFF..., zero stays 0 |
815 | | - return Vector256.GreaterThan(expanded, Vector256<ulong>.Zero); |
816 | | - } |
817 | | - |
818 | | - // Scalar fallback for non-AVX2 systems |
819 | | - return Vector256.Create( |
820 | | - bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, |
821 | | - bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, |
822 | | - bools[2] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, |
823 | | - bools[3] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul |
824 | | - ); |
825 | | - } |
826 | | - |
827 | | - /// <summary> |
828 | | - /// Create V128 mask from 16 bools for 1-byte elements. |
829 | | - /// </summary> |
830 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
831 | | - private static unsafe Vector128<byte> CreateMaskV128_1Byte(byte* bools) |
832 | | - { |
833 | | - var vec = Vector128.Load(bools); |
834 | | - var zero = Vector128<byte>.Zero; |
835 | | - var isZero = Vector128.Equals(vec, zero); |
836 | | - return Vector128.OnesComplement(isZero); |
837 | | - } |
838 | | - |
839 | | - /// <summary> |
840 | | - /// Create V128 mask from 8 bools for 2-byte elements. |
841 | | - /// Uses SSE4.1 pmovzxbw instruction for efficient expansion. |
842 | | - /// </summary> |
843 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
844 | | - private static unsafe Vector128<ushort> CreateMaskV128_2Byte(byte* bools) |
845 | | - { |
846 | | - if (Sse41.IsSupported) |
847 | | - { |
848 | | - // Load 8 bytes, zero-extend each to 16-bit |
849 | | - // pmovzxbw: byte -> word (8 bytes -> 8 words) |
850 | | - var bytes128 = Vector128.CreateScalar(*(ulong*)bools).AsByte(); |
851 | | - var expanded = Sse41.ConvertToVector128Int16(bytes128).AsUInt16(); |
852 | | - return Vector128.GreaterThan(expanded, Vector128<ushort>.Zero); |
853 | | - } |
854 | | - |
855 | | - // Scalar fallback |
856 | | - return Vector128.Create( |
857 | | - bools[0] != 0 ? (ushort)0xFFFF : (ushort)0, |
858 | | - bools[1] != 0 ? (ushort)0xFFFF : (ushort)0, |
859 | | - bools[2] != 0 ? (ushort)0xFFFF : (ushort)0, |
860 | | - bools[3] != 0 ? (ushort)0xFFFF : (ushort)0, |
861 | | - bools[4] != 0 ? (ushort)0xFFFF : (ushort)0, |
862 | | - bools[5] != 0 ? (ushort)0xFFFF : (ushort)0, |
863 | | - bools[6] != 0 ? (ushort)0xFFFF : (ushort)0, |
864 | | - bools[7] != 0 ? (ushort)0xFFFF : (ushort)0 |
865 | | - ); |
866 | | - } |
867 | | - |
868 | | - /// <summary> |
869 | | - /// Create V128 mask from 4 bools for 4-byte elements. |
870 | | - /// Uses SSE4.1 pmovzxbd instruction for efficient expansion. |
871 | | - /// </summary> |
872 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
873 | | - private static unsafe Vector128<uint> CreateMaskV128_4Byte(byte* bools) |
874 | | - { |
875 | | - if (Sse41.IsSupported) |
876 | | - { |
877 | | - // Load 4 bytes, zero-extend each to 32-bit |
878 | | - // pmovzxbd: byte -> dword (4 bytes -> 4 dwords) |
879 | | - var bytes128 = Vector128.CreateScalar(*(uint*)bools).AsByte(); |
880 | | - var expanded = Sse41.ConvertToVector128Int32(bytes128).AsUInt32(); |
881 | | - return Vector128.GreaterThan(expanded, Vector128<uint>.Zero); |
882 | | - } |
883 | | - |
884 | | - // Scalar fallback |
885 | | - return Vector128.Create( |
886 | | - bools[0] != 0 ? 0xFFFFFFFFu : 0u, |
887 | | - bools[1] != 0 ? 0xFFFFFFFFu : 0u, |
888 | | - bools[2] != 0 ? 0xFFFFFFFFu : 0u, |
889 | | - bools[3] != 0 ? 0xFFFFFFFFu : 0u |
890 | | - ); |
891 | | - } |
892 | | - |
893 | | - /// <summary> |
894 | | - /// Create V128 mask from 2 bools for 8-byte elements. |
895 | | - /// Uses SSE4.1 pmovzxbq instruction for efficient expansion. |
896 | | - /// </summary> |
897 | | - [MethodImpl(MethodImplOptions.AggressiveInlining)] |
898 | | - private static unsafe Vector128<ulong> CreateMaskV128_8Byte(byte* bools) |
899 | | - { |
900 | | - if (Sse41.IsSupported) |
901 | | - { |
902 | | - // Load 2 bytes, zero-extend each to 64-bit |
903 | | - // pmovzxbq: byte -> qword (2 bytes -> 2 qwords) |
904 | | - var bytes128 = Vector128.CreateScalar(*(ushort*)bools).AsByte(); |
905 | | - var expanded = Sse41.ConvertToVector128Int64(bytes128).AsUInt64(); |
906 | | - return Vector128.GreaterThan(expanded, Vector128<ulong>.Zero); |
907 | | - } |
908 | | - |
909 | | - // Scalar fallback |
910 | | - return Vector128.Create( |
911 | | - bools[0] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul, |
912 | | - bools[1] != 0 ? 0xFFFFFFFFFFFFFFFFul : 0ul |
913 | | - ); |
914 | | - } |
915 | | - |
916 | | - #endregion |
917 | | - |
918 | 673 | #region Scalar Fallback |
919 | 674 |
|
920 | 675 | /// <summary> |
|
0 commit comments