Skip to content

Commit 3811960

Browse files
committed
fix(asanyarray,where): pure-float object[] promotion + hot-path short-circuits
Second-round code review caught one real bug and several minor efficiency issues.

1. Fix: FindCommonNumericType promoted pure-float object[] to double

   np.asanyarray(new object[]{1.5f, 2.5f}) returned Double instead of Single. Root cause: the early-exit `if (hasDouble || hasFloat) return typeof(double)` fired before the `uniqueCount == 1` check that preserves the original dtype. Removing the hasFloat arm lets the general path handle it:
   - Pure float32 -> uniqueCount == 1 -> returns firstType (Single) -- matches NumPy
   - int + float32 -> _FindCommonType_Scalar -> returns Double -- matches NumPy NEP50
   - Pure float64 -> unchanged (still Double)
   - decimal-wins-everything early exit preserved.

   Two regression tests added:
   - ObjectArray_AllFloat_PreservesSingle
   - ObjectArray_MixedIntAndFloat32_PromotesToDouble

2. Perf: skip type promotion in np.where when x.dtype == y.dtype

   Previously _FindCommonType(x, y) always ran, even when both operands shared a dtype. Short-circuit to x.GetTypeCode in that case, saving one dict lookup + two astype traversals per call. The NEP50 lookup still runs when dtypes differ, preserving scalar+array promotion semantics.

3. Perf: skip broadcast_arrays when all three shapes already match

   broadcast_arrays allocates three fresh NDArrays plus helper Shape[]. For the common case of np.where(mask, arr, other_arr) where all three arrays share a shape, this is wasted. Skip it when condition.Shape == x.Shape == y.Shape (Shape == compares by dimensions).

4. Perf: cache Vector256/Vector128 generic MethodInfo

   EmitWhereV256BodyWithOffset and EmitWhereV128BodyWithOffset did Array.Find(typeof(Vector*).GetMethods(), ...) three times per call, each scanning ~100 methods. Per kernel generation (4-way unrolled + 1 remainder call = 5 calls), that was 15 reflection scans per T, or ~180 on first use across all 12 dtypes. Cached as six static readonly fields; only MakeGenericMethod(typeof(T)) runs per call.

5. Polish: doc + error message
   - where(NDArray) xmldoc was copy-pasted from the 3-arg overload ("Return elements chosen from x or y"); rewritten to describe nonzero semantics.
   - object[] NotSupportedException now names the actual problem ("element type is not a supported NumSharp dtype") instead of just reporting the length.

Verified: 180 np.where + np.asanyarray tests pass on net8.0 + net10.0.
1 parent f0473d2 commit 3811960

4 files changed

Lines changed: 72 additions & 41 deletions

File tree

src/NumSharp.Core/APIs/np.where.cs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ namespace NumSharp
77
public static partial class np
88
{
99
/// <summary>
10-
/// Return elements chosen from `x` or `y` depending on `condition`.
10+
/// Equivalent to <see cref="nonzero(NDArray)"/>: returns the indices where
11+
/// <paramref name="condition"/> is non-zero.
1112
/// </summary>
12-
/// <param name="condition">Where True, yield `x`, otherwise yield `y`.</param>
13-
/// <returns>Tuple of arrays with indices where condition is non-zero (equivalent to np.nonzero).</returns>
13+
/// <param name="condition">Input array. Non-zero entries yield their indices.</param>
14+
/// <returns>Tuple of arrays with indices where condition is non-zero, one per dimension.</returns>
1415
/// <remarks>https://numpy.org/doc/stable/reference/generated/numpy.where.html</remarks>
1516
public static NDArray<long>[] where(NDArray condition)
1617
{
@@ -62,17 +63,29 @@ public static NDArray where(NDArray condition, object x, object y)
6263
/// </summary>
6364
private static NDArray where_internal(NDArray condition, NDArray x, NDArray y)
6465
{
65-
// Broadcast all three arrays to common shape
66-
var broadcasted = broadcast_arrays(condition, x, y);
67-
var cond = broadcasted[0];
68-
var xArr = broadcasted[1];
69-
var yArr = broadcasted[2];
66+
// Skip broadcast_arrays (which allocates 3 NDArrays + helper arrays) when all three
67+
// already share a shape — the frequent case of np.where(mask, arr, other_arr).
68+
NDArray cond, xArr, yArr;
69+
if (condition.Shape == x.Shape && x.Shape == y.Shape)
70+
{
71+
cond = condition;
72+
xArr = x;
73+
yArr = y;
74+
}
75+
else
76+
{
77+
var broadcasted = broadcast_arrays(condition, x, y);
78+
cond = broadcasted[0];
79+
xArr = broadcasted[1];
80+
yArr = broadcasted[2];
81+
}
7082

71-
// Determine output dtype using existing type promotion system
72-
// _FindCommonType already handles NEP50: scalar+array → array wins
73-
var outType = _FindCommonType(x, y);
83+
// When x and y already agree, skip the NEP50 promotion lookup. Otherwise defer to
84+
// _FindCommonType which handles the scalar+array NEP50 rules.
85+
var outType = x.GetTypeCode == y.GetTypeCode
86+
? x.GetTypeCode
87+
: _FindCommonType(x, y);
7488

75-
// Convert x and y to output type if needed (required for kernel and iterator paths)
7689
if (xArr.GetTypeCode != outType)
7790
xArr = xArr.astype(outType, copy: false);
7891
if (yArr.GetTypeCode != outType)

src/NumSharp.Core/Backends/Kernels/ILKernelGenerator.Where.cs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -251,18 +251,27 @@ private static void EmitWhereSIMDLoop<T>(ILGenerator il, LocalBuilder locI) wher
251251
il.MarkLabel(lblVectorLoopEnd);
252252
}
253253

254+
// Generic method definitions are cached once at class init; MakeGenericMethod is the
255+
// only per-T work needed during kernel generation.
256+
private static readonly MethodInfo _v256LoadGeneric = Array.Find(typeof(Vector256).GetMethods(),
257+
m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!;
258+
private static readonly MethodInfo _v256StoreGeneric = Array.Find(typeof(Vector256).GetMethods(),
259+
m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!;
260+
private static readonly MethodInfo _v256ConditionalSelectGeneric = Array.Find(typeof(Vector256).GetMethods(),
261+
m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!;
262+
263+
private static readonly MethodInfo _v128LoadGeneric = Array.Find(typeof(Vector128).GetMethods(),
264+
m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!;
265+
private static readonly MethodInfo _v128StoreGeneric = Array.Find(typeof(Vector128).GetMethods(),
266+
m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!;
267+
private static readonly MethodInfo _v128ConditionalSelectGeneric = Array.Find(typeof(Vector128).GetMethods(),
268+
m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!;
269+
254270
private static void EmitWhereV256BodyWithOffset<T>(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged
255271
{
256-
// Get Vector256 methods via reflection - need to find generic method definitions first
257-
var loadMethod = Array.Find(typeof(Vector256).GetMethods(),
258-
m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!
259-
.MakeGenericMethod(typeof(T));
260-
var storeMethod = Array.Find(typeof(Vector256).GetMethods(),
261-
m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!
262-
.MakeGenericMethod(typeof(T));
263-
var selectMethod = Array.Find(typeof(Vector256).GetMethods(),
264-
m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!
265-
.MakeGenericMethod(typeof(T));
272+
var loadMethod = _v256LoadGeneric.MakeGenericMethod(typeof(T));
273+
var storeMethod = _v256StoreGeneric.MakeGenericMethod(typeof(T));
274+
var selectMethod = _v256ConditionalSelectGeneric.MakeGenericMethod(typeof(T));
266275

267276
// Load address: cond + (i + offset)
268277
il.Emit(OpCodes.Ldarg_0); // cond
@@ -327,16 +336,9 @@ private static void EmitWhereV256BodyWithOffset<T>(ILGenerator il, LocalBuilder
327336

328337
private static void EmitWhereV128BodyWithOffset<T>(ILGenerator il, LocalBuilder locI, long elementSize, long offset) where T : unmanaged
329338
{
330-
// Get Vector128 methods via reflection - need to find generic method definitions first
331-
var loadMethod = Array.Find(typeof(Vector128).GetMethods(),
332-
m => m.Name == "Load" && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)!
333-
.MakeGenericMethod(typeof(T));
334-
var storeMethod = Array.Find(typeof(Vector128).GetMethods(),
335-
m => m.Name == "Store" && m.IsGenericMethodDefinition && m.GetParameters().Length == 2)!
336-
.MakeGenericMethod(typeof(T));
337-
var selectMethod = Array.Find(typeof(Vector128).GetMethods(),
338-
m => m.Name == "ConditionalSelect" && m.IsGenericMethodDefinition)!
339-
.MakeGenericMethod(typeof(T));
339+
var loadMethod = _v128LoadGeneric.MakeGenericMethod(typeof(T));
340+
var storeMethod = _v128StoreGeneric.MakeGenericMethod(typeof(T));
341+
var selectMethod = _v128ConditionalSelectGeneric.MakeGenericMethod(typeof(T));
340342

341343
// Load address: cond + (i + offset)
342344
il.Emit(OpCodes.Ldarg_0);

src/NumSharp.Core/Creation/np.asanyarray.cs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public static NDArray asanyarray(in object a, Type dtype = null) //todo support
3232
// supported element type.
3333
ret = ConvertNonGenericEnumerable(objArr);
3434
if (ret is null)
35-
throw new NotSupportedException($"Unable to resolve asanyarray for object array of length {objArr.Length}.");
35+
throw new NotSupportedException($"Unable to resolve asanyarray for object[] (length {objArr.Length}): element type is not a supported NumSharp dtype.");
3636
break;
3737
case Array array:
3838
ret = new NDArray(array);
@@ -206,8 +206,6 @@ private static Type FindCommonNumericType(List<object> items)
206206
{
207207
var span = CollectionsMarshal.AsSpan(items);
208208

209-
bool hasDouble = false;
210-
bool hasFloat = false;
211209
Type firstType = null;
212210

213211
// At most 12 unique NPTypeCode values exist; bound the stackalloc accordingly
@@ -221,12 +219,10 @@ private static Type FindCommonNumericType(List<object> items)
221219
var t = span[i].GetType();
222220
firstType ??= t;
223221

222+
// decimal wins everything in NumPy promotion
224223
if (t == typeof(decimal))
225224
return typeof(decimal);
226225

227-
if (t == typeof(double)) hasDouble = true;
228-
else if (t == typeof(float)) hasFloat = true;
229-
230226
var code = t.GetTypeCode();
231227
var bit = 1u << (int)code;
232228
if ((seenMask & bit) == 0)
@@ -236,9 +232,6 @@ private static Type FindCommonNumericType(List<object> items)
236232
}
237233
}
238234

239-
if (hasDouble || hasFloat)
240-
return typeof(double);
241-
242235
if (uniqueCount == 1)
243236
return firstType ?? typeof(double);
244237

test/NumSharp.UnitTest/Creation/np.asanyarray.Tests.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,29 @@ public void ObjectArray_Empty_ReturnsFloat64()
791791
result.dtype.Should().Be(typeof(double));
792792
}
793793

794+
[TestMethod]
795+
public void ObjectArray_AllFloat_PreservesSingle()
796+
{
797+
// Regression: an earlier FindCommonNumericType short-circuit promoted any float
798+
// to double. NumPy preserves float32 for homogeneous float32 inputs.
799+
var arr = new object[] { 1.5f, 2.5f, 3.5f };
800+
var result = np.asanyarray(arr);
801+
802+
result.dtype.Should().Be(typeof(float));
803+
result.Should().BeShaped(3).And.BeOfValues(1.5f, 2.5f, 3.5f);
804+
}
805+
806+
[TestMethod]
807+
public void ObjectArray_MixedIntAndFloat32_PromotesToDouble()
808+
{
809+
// int + float32 -> float64 per NumPy NEP50.
810+
var arr = new object[] { 1, 2.5f, 3 };
811+
var result = np.asanyarray(arr);
812+
813+
result.dtype.Should().Be(typeof(double));
814+
result.Should().BeShaped(3);
815+
}
816+
794817
#endregion
795818
}
796819
}

0 commit comments

Comments
 (0)