|
1 | | -from typing import Any |
| 1 | +# mypy: disable-error-code="no-redef" |
2 | 2 |
|
3 | | -# requires numpy < 2 |
4 | | -import numpy.array_api as np |
| 3 | +from types import ModuleType |
| 4 | +from typing import TypeAlias |
| 5 | + |
| 6 | +import numpy.array_api as np # type: ignore[import-not-found, unused-ignore] |
5 | 7 |
|
6 | 8 | import array_api_typing as xpt |
7 | 9 |
|
8 | | -### |
9 | | -# Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`. |
| 10 | +# DType aliases |
| 11 | +F32: TypeAlias = np.float32 |
| 12 | +I32: TypeAlias = np.int32 |
| 13 | + |
| 14 | +# Define NDArrays against which we can test the protocols |
| 15 | +nparr = np.eye(2) |
| 16 | +nparr_i32 = np.array([1], dtype=I32) |
| 17 | +nparr_f32 = np.array([1.0], dtype=F32) |
| 18 | +nparr_b = np.array([True], dtype=np.bool_) |
| 19 | + |
| 20 | +# ========================================================= |
| 21 | +# `xpt.HasArrayNamespace` |
| 22 | + |
| 23 | +_: xpt.HasArrayNamespace[ModuleType] = nparr |
| 24 | +_: xpt.HasArrayNamespace[ModuleType] = nparr_i32 |
| 25 | +_: xpt.HasArrayNamespace[ModuleType] = nparr_f32 |
| 26 | +_: xpt.HasArrayNamespace[ModuleType] = nparr_b |
| 27 | + |
| 28 | +# Check `__array_namespace__` method |
| 29 | +a_ns: xpt.HasArrayNamespace[ModuleType] = nparr |
| 30 | +ns: ModuleType = a_ns.__array_namespace__() |
10 | 31 |
|
11 | | -arr = np.eye(2) |
12 | | -arr_namespace: xpt.HasArrayNamespace[Any] = arr |
| 32 | +# Incorrect values are caught when using `__array_namespace__` and |
| 33 | +# backpropagated to the type of `a_ns` |
| 34 | +_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught |
0 commit comments