@@ -706,6 +706,21 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
706706 axis = 0
707707 return torch .index_select (x , axis , indices , ** kwargs )
708708
709+ def sign (x : array , / ) -> array :
710+ # torch sign() does not support complex numbers and does not propagate
711+ # nans. See https://github.com/data-apis/array-api-compat/issues/136
712+ if x .dtype .is_complex :
713+ out = x / torch .abs (x )
714+ # sign(0) = 0 but the above formula would give nan
715+ out [x == 0 + 0j ] = 0 + 0j
716+ return out
717+ else :
718+ out = torch .sign (x )
719+ if x .dtype .is_floating_point :
720+ out [torch .isnan (x )] = torch .nan
721+ return out
722+
723+
709724__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' ,
710725 'newaxis' , 'conj' , 'add' , 'atan2' , 'bitwise_and' ,
711726 'bitwise_left_shift' , 'bitwise_or' , 'bitwise_right_shift' ,
@@ -719,6 +734,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
719734 'broadcast_arrays' , 'UniqueAllResult' , 'UniqueCountsResult' ,
720735 'UniqueInverseResult' , 'unique_all' , 'unique_counts' ,
721736 'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
722- 'vecdot' , 'tensordot' , 'isdtype' , 'take' ]
737+ 'vecdot' , 'tensordot' , 'isdtype' , 'take' , 'sign' ]
723738
724739_all_ignore = ['torch' , 'get_xp' ]
0 commit comments