@@ -238,6 +238,65 @@ class DPNPC_id final
238238 return output_size;
239239 }
240240
241+ /* *
242+ * @ingroup BACKEND_UTILS
243+ * @brief Broadcast input data to specified shape.
244+ *
245+ * Set output shape to use in computation of input index by output index.
246+ *
247+ * @note this function is designed for non-SYCL environment execution
248+ *
249+ * @param [in] __shape Output shape.
250+ */
251+ inline void broadcast_to_shape (const std::vector<size_type>& __shape)
252+ {
253+ if (axis_use)
254+ {
255+ return ;
256+ }
257+
258+ if (broadcastable (input_shape, input_shape_size, __shape))
259+ {
260+ free_broadcast_axes_memory ();
261+ free_output_memory ();
262+
263+ std::vector<size_type> valid_axes;
264+ broadcast_use = true ;
265+
266+ output_shape_size = __shape.size ();
267+ const size_type output_shape_size_in_bytes = output_shape_size * sizeof (size_type);
268+ output_shape = reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_shape_size_in_bytes));
269+
270+ for (int irit = input_shape_size - 1 , orit = output_shape_size - 1 ; orit >= 0 ; --irit, --orit)
271+ {
272+ output_shape[orit] = __shape[orit];
273+
274+ // ex: input_shape = {7, 1, 5}, output_shape = {8, 7, 6, 5} => valid_axes = {0, 2}
275+ if (irit < 0 || input_shape[irit] != output_shape[orit])
276+ {
277+ valid_axes.insert (valid_axes.begin (), orit);
278+ }
279+ }
280+
281+ broadcast_axes_size = valid_axes.size ();
282+ const size_type broadcast_axes_size_in_bytes = broadcast_axes_size * sizeof (size_type);
283+ broadcast_axes = reinterpret_cast <size_type*>(dpnp_memory_alloc_c (broadcast_axes_size_in_bytes));
284+ std::copy (valid_axes.begin (), valid_axes.end (), broadcast_axes);
285+
286+ output_size = std::accumulate (
287+ output_shape, output_shape + output_shape_size, size_type (1 ), std::multiplies<size_type>());
288+
289+ output_shape_strides = reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_shape_size_in_bytes));
290+ get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);
291+
292+ iteration_size = 1 ;
293+
294+ // make thread private storage for each shape by multiplying memory
295+ sycl_output_xyz =
296+ reinterpret_cast <size_type*>(dpnp_memory_alloc_c (output_size * output_shape_size_in_bytes));
297+ }
298+ }
299+
241300 /* *
242301 * @ingroup BACKEND_UTILS
243302 * @brief Set axis for the data object to use in computation.
@@ -285,6 +344,11 @@ class DPNPC_id final
285344 */
286345 inline void set_axes (const std::vector<long >& __axes)
287346 {
347+ if (broadcast_use)
348+ {
349+ return ;
350+ }
351+
288352 if (!__axes.empty () && input_shape_size)
289353 {
290354 free_axes_memory ();
@@ -368,6 +432,11 @@ class DPNPC_id final
368432 // / this function is designed for SYCL environment execution
369433 inline reference operator [](size_type __n) const
370434 {
435+ if (broadcast_use)
436+ {
437+ return *begin (__n);
438+ }
439+
371440 const iterator it = begin ();
372441 return it[__n];
373442 }
@@ -430,6 +499,24 @@ class DPNPC_id final
430499 }
431500 }
432501 }
502+ else if (broadcast_use)
503+ {
504+ assert (output_global_id < output_size);
505+
506+ // use thread private storage
507+ size_type* sycl_output_xyz_thread = sycl_output_xyz + (output_global_id * output_shape_size);
508+
509+ get_xyz_by_id_inkernel (output_global_id, output_shape_strides, output_shape_size, sycl_output_xyz_thread);
510+
511+ for (int irit = input_shape_size - 1 , orit = output_shape_size - 1 ; irit >= 0 ; --irit, --orit)
512+ {
513+ size_type* broadcast_axes_end = broadcast_axes + broadcast_axes_size;
514+ if (std::find (broadcast_axes, broadcast_axes_end, orit) == broadcast_axes_end)
515+ {
516+ input_global_id += (sycl_output_xyz_thread[orit] * input_shape_strides[irit]);
517+ }
518+ }
519+ }
433520
434521 return input_global_id;
435522 }
@@ -447,6 +534,13 @@ class DPNPC_id final
447534 axes_shape_strides = nullptr ;
448535 }
449536
537+ void free_broadcast_axes_memory ()
538+ {
539+ broadcast_axes_size = size_type{};
540+ dpnp_memory_free_c (broadcast_axes);
541+ broadcast_axes = nullptr ;
542+ }
543+
450544 void free_input_memory ()
451545 {
452546 input_size = size_type{};
@@ -480,6 +574,7 @@ class DPNPC_id final
480574 void free_memory ()
481575 {
482576 free_axes_memory ();
577+ free_broadcast_axes_memory ();
483578 free_input_memory ();
484579 free_iteration_memory ();
485580 free_output_memory ();
@@ -494,6 +589,10 @@ class DPNPC_id final
494589 std::vector<size_type> axes; /* *< input shape reduction axes */
495590 bool axis_use = false ;
496591
592+ size_type* broadcast_axes = nullptr ; /* *< input shape broadcast axes */
593+ size_type broadcast_axes_size = size_type{}; /* *< input shape broadcast axes size */
594+ bool broadcast_use = false ;
595+
497596 size_type output_size = size_type{}; /* *< output array size. Expected is same as GWS */
498597 size_type* output_shape = nullptr ; /* *< output array shape */
499598 size_type output_shape_size = size_type{}; /* *< output array shape size */
0 commit comments