@@ -530,11 +530,22 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
530530 raise ValueError ("Input array must be at least 1-d." )
531531 return tuple (xp .moveaxis (x , axis , 0 ))
532532
533+ # numpy 1.26 does not use the standard definition for sign on complex numbers
534+
535+ def sign (x : array , / , xp , ** kwargs ) -> array :
536+ if isdtype (x .dtype , 'complex floating' , xp = xp ):
537+ out = (x / xp .abs (x , ** kwargs ))[...]
538+ # sign(0) = 0 but the above formula would give nan
539+ out [x == 0 + 0j ] = 0 + 0j
540+ return out [()]
541+ else :
542+ return xp .sign (x , ** kwargs )
543+
533544__all__ = ['arange' , 'empty' , 'empty_like' , 'eye' , 'full' , 'full_like' ,
534545 'linspace' , 'ones' , 'ones_like' , 'zeros' , 'zeros_like' ,
535546 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
536547 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
537548 'astype' , 'std' , 'var' , 'cumulative_sum' , 'clip' , 'permute_dims' ,
538549 'reshape' , 'argsort' , 'sort' , 'nonzero' , 'ceil' , 'floor' , 'trunc' ,
539550 'matmul' , 'matrix_transpose' , 'tensordot' , 'vecdot' , 'isdtype' ,
540- 'unstack' ]
551+ 'unstack' , 'sign' ]
0 commit comments