@@ -700,6 +700,21 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
700700 axis = 0
701701 return torch .index_select (x , axis , indices , ** kwargs )
702702
703+ def sign (x : array , / ) -> array :
704+ # torch sign() does not support complex numbers and does not propagate
705+ # nans. See https://github.com/data-apis/array-api-compat/issues/136
706+ if x .dtype .is_complex :
707+ out = x / torch .abs (x )
708+ # sign(0) = 0 but the above formula would give nan
709+ out [x == 0 + 0j ] = 0 + 0j
710+ return out
711+ else :
712+ out = torch .sign (x )
713+ if x .dtype .is_floating_point :
714+ out [torch .isnan (x )] = torch .nan
715+ return out
716+
717+
703718__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' ,
704719 'newaxis' , 'add' , 'atan2' , 'bitwise_and' , 'bitwise_left_shift' ,
705720 'bitwise_or' , 'bitwise_right_shift' , 'bitwise_xor' , 'divide' ,
@@ -713,6 +728,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
713728 'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
714729 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
715730 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
716- 'take' ]
731+ 'take' , 'sign' ]
717732
718733_all_ignore = ['torch' , 'get_xp' ]
0 commit comments