|
8 | 8 | from typing_extensions import TypeVar |
9 | 9 |
|
10 | 10 | NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType) |
| 11 | +DTypeT_co = TypeVar("DTypeT_co", covariant=True) |
11 | 12 |
|
12 | 13 |
|
13 | 14 | class HasArrayNamespace(Protocol[NamespaceT_co]): |
@@ -57,15 +58,28 @@ def __array_namespace__( |
57 | 58 | ... |
58 | 59 |
|
59 | 60 |
|
| 61 | +class HasDType(Protocol[DTypeT_co]): |
| 62 | + """Protocol for array classes that have a data type attribute.""" |
| 63 | + |
| 64 | + @property |
| 65 | + def dtype(self, /) -> DTypeT_co: |
| 66 | + """Data type of the array elements.""" |
| 67 | + ... |
| 68 | + |
| 69 | + |
60 | 70 | class Array( |
61 | 71 | HasArrayNamespace[NamespaceT_co], |
| 72 | + # ------ Attributes ------- |
| 73 | + HasDType[DTypeT_co], |
62 | 74 | # ------------------------- |
63 | | - Protocol[NamespaceT_co], |
| 75 | + Protocol[DTypeT_co, NamespaceT_co], |
64 | 76 | ): |
65 | 77 | """Array API specification for array object attributes and methods. |
66 | 78 |
|
67 | | - The type is: ``Array[+NamespaceT = ModuleType] = Array[NamespaceT]`` where: |
| 79 | + The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT, |
| 80 | + NamespaceT]`` where: |
68 | 81 |
|
| 82 | + - `DTypeT` is the data type of the array elements. |
69 | 83 | - `NamespaceT` is the type of the array namespace. It defaults to |
70 | 84 | `ModuleType`, which is the most common form of array namespace (e.g., |
71 | 85 | `numpy`, `cupy`, etc.). However, it can be any type, e.g. a |
|
0 commit comments