|
35 | 35 | BinaryElementwiseFunc |
36 | 36 | ) |
37 | 37 | import dpctl.tensor._tensor_impl as ti |
| 38 | +import dpctl.tensor as dpt |
| 39 | +import dpctl |
| 40 | + |
| 41 | +import numpy |
38 | 42 |
|
39 | 43 |
|
40 | 44 | __all__ = [ |
@@ -125,12 +129,27 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]): |
125 | 129 | return vmi._div(sycl_queue, src1, src2, dst, depends) |
126 | 130 | return ti._divide(src1, src2, dst, sycl_queue, depends) |
127 | 131 |
|
def _call_divide_inplace(lhs, rhs, sycl_queue, depends=None):
    """
    In place workaround until dpctl.tensor provides the functionality.

    The quotient is first computed into a temporary array and then copied
    back into `lhs`.

    Parameters
    ----------
    lhs
        Left operand; it is the destination of the final copy.
    rhs
        Right operand.
    sycl_queue
        Queue forwarded to `_call_divide` and to the copy call.
    depends : list, optional
        Events forwarded to `_call_divide`. Defaults to an empty list.

    Returns
    -------
    tuple
        The pair of events returned by
        ``ti._copy_usm_ndarray_into_usm_ndarray``.
    """

    if depends is None:
        depends = []

    # allocate temporary memory for out array
    # NOTE: the dtypes are passed as two separate arguments so that NumPy
    # promotes them; a single tuple argument is not promoted element-wise.
    tmp_dtype = numpy.result_type(lhs.dtype, rhs.dtype)
    out = dpt.empty_like(lhs, dtype=tmp_dtype)

    # call a general callback
    div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends)

    # store the result into left input array and return events;
    # the copy is ordered after the division through `div_ev_`
    cp_ht_, cp_ev_ = ti._copy_usm_ndarray_into_usm_ndarray(
        src=out, dst=lhs, sycl_queue=sycl_queue, depends=[div_ev_]
    )

    # wait on the first event returned by the division call before returning
    dpctl.SyclEvent.wait_for([div_ht_])
    return (cp_ht_, cp_ev_)
| 145 | + |
128 | 146 | # dpctl.tensor only works with usm_ndarray or scalar |
129 | 147 | x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1) |
130 | 148 | x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) |
131 | 149 | out_usm = None if out is None else dpnp.get_usm_ndarray(out) |
132 | 150 |
|
133 | | - func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_) |
| 151 | + func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, |
| 152 | + _divide_docstring_, _call_divide_inplace) |
134 | 153 | res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) |
135 | 154 | return dpnp_array._create_from_usm_ndarray(res_usm) |
136 | 155 |
|
@@ -208,6 +227,11 @@ def dpnp_subtract(x1, x2, out=None, order='K'): |
208 | 227 |
|
209 | 228 | """ |
210 | 229 |
|
| 230 | + # TODO: discuss with dpctl if the check is needed to be moved there |
| 231 | + if not dpnp.isscalar(x1) and not dpnp.isscalar(x2) and x1.dtype == x2.dtype == dpnp.bool: |
| 232 | + raise TypeError("DPNP boolean subtract, the `-` operator, is not supported, " |
| 233 | + "use the bitwise_xor, the `^` operator, or the logical_xor function instead.") |
| 234 | + |
211 | 235 | # dpctl.tensor only works with usm_ndarray or scalar |
212 | 236 | x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1) |
213 | 237 | x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) |
|
0 commit comments