|
5 | 5 | ) |
6 | 6 |
|
7 | 7 | from types import ModuleType |
8 | | -from typing import Literal, Protocol |
| 8 | +from typing import Literal, Protocol, Self |
9 | 9 | from typing_extensions import TypeVar |
10 | 10 |
|
11 | 11 | NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType) |
@@ -77,10 +77,32 @@ def device(self) -> object: # TODO: more specific type |
77 | 77 | ... |
78 | 78 |
|
79 | 79 |
|
| 80 | +class HasMatrixTranspose(Protocol): |
| 81 | + """Protocol for array classes that have a matrix transpose attribute.""" |
| 82 | + |
| 83 | + @property |
| 84 | + def mT(self) -> Self: # noqa: N802 |
| 85 | + """Transpose of a matrix (or a stack of matrices). |
| 86 | +
|
| 87 | + If an array instance has fewer than two dimensions, an error should be |
| 88 | + raised. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + Self: array whose last two dimensions (axes) are permuted in reverse |
| 92 | + order relative to original array (i.e., for an array instance |
| 93 | + having shape `(..., M, N)`, the returned array must have shape |
| 94 | + `(..., N, M))`. The returned array must have the same data type |
| 95 | + as the original array. |
| 96 | +
|
| 97 | + """ |
| 98 | + ... |
| 99 | + |
| 100 | + |
80 | 101 | class Array( |
81 | 102 | # ------ Attributes ------- |
82 | 103 | HasDType[DTypeT_co], |
83 | 104 | HasDevice, |
| 105 | + HasMatrixTranspose, |
84 | 106 | # ------- Methods --------- |
85 | 107 | HasArrayNamespace[NamespaceT_co], |
86 | 108 | # ------------------------- |
|
0 commit comments