Skip to content

Commit 18c3d61

Browse files
Add missing _place_impl() to _copy_utils.py
1 parent 23164ac commit 18c3d61

1 file changed

Lines changed: 104 additions & 0 deletions

File tree

dpctl_ext/tensor/_copy_utils.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,110 @@ def _prepare_indices_arrays(inds, q, usm_type):
310310
return inds
311311

312312

313+
def _place_impl(ary, ary_mask, vals, axis=0):
314+
"""
315+
Extract elements of ary by applying mask starting from slot
316+
dimension axis.
317+
"""
318+
if not isinstance(ary, dpt.usm_ndarray):
319+
raise TypeError(
320+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
321+
)
322+
if isinstance(ary_mask, dpt.usm_ndarray):
323+
exec_q = dpctl.utils.get_execution_queue(
324+
(
325+
ary.sycl_queue,
326+
ary_mask.sycl_queue,
327+
)
328+
)
329+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
330+
(
331+
ary.usm_type,
332+
ary_mask.usm_type,
333+
)
334+
)
335+
if exec_q is None:
336+
raise dpctl.utils.ExecutionPlacementError(
337+
"arrays have different associated queues. "
338+
"Use `y.to_device(x.device)` to migrate."
339+
)
340+
elif isinstance(ary_mask, np.ndarray):
341+
exec_q = ary.sycl_queue
342+
coerced_usm_type = ary.usm_type
343+
ary_mask = dpt.asarray(
344+
ary_mask, usm_type=coerced_usm_type, sycl_queue=exec_q
345+
)
346+
else:
347+
raise TypeError(
348+
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
349+
f"{type(ary_mask)}"
350+
)
351+
if exec_q is not None:
352+
if not isinstance(vals, dpt.usm_ndarray):
353+
vals = dpt.asarray(
354+
vals,
355+
dtype=ary.dtype,
356+
usm_type=coerced_usm_type,
357+
sycl_queue=exec_q,
358+
)
359+
else:
360+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
361+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
362+
(
363+
coerced_usm_type,
364+
vals.usm_type,
365+
)
366+
)
367+
if exec_q is None:
368+
raise dpctl.utils.ExecutionPlacementError(
369+
"arrays have different associated queues. "
370+
"Use `Y.to_device(X.device)` to migrate."
371+
)
372+
ary_nd = ary.ndim
373+
pp = normalize_axis_index(operator.index(axis), ary_nd)
374+
mask_nd = ary_mask.ndim
375+
if pp < 0 or pp + mask_nd > ary_nd:
376+
raise ValueError(
377+
"Parameter p is inconsistent with input array dimensions"
378+
)
379+
mask_nelems = ary_mask.size
380+
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
381+
cumsum = dpt.empty(
382+
mask_nelems,
383+
dtype=cumsum_dt,
384+
usm_type=coerced_usm_type,
385+
device=ary_mask.device,
386+
)
387+
exec_q = cumsum.sycl_queue
388+
_manager = dpctl.utils.SequentialOrderManager[exec_q]
389+
dep_ev = _manager.submitted_events
390+
mask_count = ti.mask_positions(
391+
ary_mask, cumsum, sycl_queue=exec_q, depends=dep_ev
392+
)
393+
expected_vals_shape = (
394+
ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
395+
)
396+
if vals.dtype == ary.dtype:
397+
rhs = vals
398+
else:
399+
rhs = dpt.astype(vals, ary.dtype)
400+
rhs = dpt.broadcast_to(rhs, expected_vals_shape)
401+
if mask_nelems == 0:
402+
return
403+
dep_ev = _manager.submitted_events
404+
hev, pl_ev = ti._place(
405+
dst=ary,
406+
cumsum=cumsum,
407+
axis_start=pp,
408+
axis_end=pp + mask_nd,
409+
rhs=rhs,
410+
sycl_queue=exec_q,
411+
depends=dep_ev,
412+
)
413+
_manager.add_event_pair(hev, pl_ev)
414+
return
415+
416+
313417
def _put_multi_index(ary, inds, p, vals, mode=0):
314418
if not isinstance(ary, dpt.usm_ndarray):
315419
raise TypeError(

0 commit comments

Comments
 (0)