@@ -310,6 +310,110 @@ def _prepare_indices_arrays(inds, q, usm_type):
310310 return inds
311311
312312
def _place_impl(ary, ary_mask, vals, axis=0):
    """
    Place elements of ``vals`` into ``ary`` at the positions selected by
    ``ary_mask``, with the mask applied to the dimensions of ``ary``
    starting from dimension ``axis``.

    Args:
        ary (dpctl.tensor.usm_ndarray):
            Destination array; handed to the placement kernel as ``dst``.
        ary_mask (dpctl.tensor.usm_ndarray or numpy.ndarray):
            Mask array. A ``numpy.ndarray`` is converted with
            ``dpt.asarray`` using the queue and USM type of ``ary``.
        vals:
            Values to place. Anything that is not a ``usm_ndarray`` is
            converted with ``dpt.asarray`` using the data type of ``ary``.
            The values are cast to ``ary.dtype`` if needed and broadcast to
            ``ary.shape[:axis] + (mask_count,) + ary.shape[axis + mask_nd:]``
            where ``mask_count`` is the value returned by
            ``ti.mask_positions`` and ``mask_nd`` is ``ary_mask.ndim``.
        axis (int):
            First dimension of ``ary`` that the mask applies to.
            Converted with ``operator.index`` and normalized against
            ``ary.ndim``. Default: ``0``.

    Returns:
        None. When the mask has zero elements no placement kernel is
        submitted.

    Raises:
        TypeError:
            If ``ary`` is not a ``usm_ndarray``, or ``ary_mask`` is neither
            a ``usm_ndarray`` nor a ``numpy.ndarray``.
        dpctl.utils.ExecutionPlacementError:
            If no common execution queue can be determined for the
            array arguments.
        ValueError:
            If the mask dimensions starting at ``axis`` do not fit within
            the dimensions of ``ary``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )

    # Determine the execution queue and USM type shared by ary and the mask.
    if isinstance(ary_mask, dpt.usm_ndarray):
        exec_q = dpctl.utils.get_execution_queue(
            (
                ary.sycl_queue,
                ary_mask.sycl_queue,
            )
        )
        coerced_usm_type = dpctl.utils.get_coerced_usm_type(
            (
                ary.usm_type,
                ary_mask.usm_type,
            )
        )
        if exec_q is None:
            raise dpctl.utils.ExecutionPlacementError(
                "arrays have different associated queues. "
                "Use `y.to_device(x.device)` to migrate."
            )
    elif isinstance(ary_mask, np.ndarray):
        # A host mask carries no queue of its own: adopt the one of ary.
        exec_q = ary.sycl_queue
        coerced_usm_type = ary.usm_type
        ary_mask = dpt.asarray(
            ary_mask, usm_type=coerced_usm_type, sycl_queue=exec_q
        )
    else:
        raise TypeError(
            "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
            f"{type(ary_mask)}"
        )

    # Fold vals into the queue / USM type decision.
    if exec_q is not None:
        if not isinstance(vals, dpt.usm_ndarray):
            vals = dpt.asarray(
                vals,
                dtype=ary.dtype,
                usm_type=coerced_usm_type,
                sycl_queue=exec_q,
            )
        else:
            exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
            coerced_usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    coerced_usm_type,
                    vals.usm_type,
                )
            )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "arrays have different associated queues. "
            "Use `Y.to_device(X.device)` to migrate."
        )

    ary_nd = ary.ndim
    pp = normalize_axis_index(operator.index(axis), ary_nd)
    mask_nd = ary_mask.ndim
    # NOTE(review): normalize_axis_index presumably already yields a
    # non-negative axis; the `pp < 0` test is kept as a defensive check.
    if pp < 0 or pp + mask_nd > ary_nd:
        raise ValueError(
            "Parameter axis is inconsistent with input array dimensions"
        )

    mask_nelems = ary_mask.size
    # Use the narrower integer type while the mask size stays below
    # int32_t_max.
    cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
    cumsum = dpt.empty(
        mask_nelems,
        dtype=cumsum_dt,
        usm_type=coerced_usm_type,
        device=ary_mask.device,
    )
    # From here on, work is submitted to the queue cumsum was allocated on.
    exec_q = cumsum.sycl_queue
    _manager = dpctl.utils.SequentialOrderManager[exec_q]
    dep_ev = _manager.submitted_events
    mask_count = ti.mask_positions(
        ary_mask, cumsum, sycl_queue=exec_q, depends=dep_ev
    )
    expected_vals_shape = (
        ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
    )
    if vals.dtype == ary.dtype:
        rhs = vals
    else:
        rhs = dpt.astype(vals, ary.dtype)
    rhs = dpt.broadcast_to(rhs, expected_vals_shape)
    if mask_nelems == 0:
        # Empty mask: vals has been validated above, nothing to submit.
        return
    # Re-read the pending events: the cast above may have submitted work.
    dep_ev = _manager.submitted_events
    hev, pl_ev = ti._place(
        dst=ary,
        cumsum=cumsum,
        axis_start=pp,
        axis_end=pp + mask_nd,
        rhs=rhs,
        sycl_queue=exec_q,
        depends=dep_ev,
    )
    _manager.add_event_pair(hev, pl_ev)
415+
416+
313417def _put_multi_index (ary , inds , p , vals , mode = 0 ):
314418 if not isinstance (ary , dpt .usm_ndarray ):
315419 raise TypeError (
0 commit comments