@@ -479,6 +479,46 @@ void simplify_iteration_space(int &nd,
479479 }
480480}
481481
482+ sycl::event _populate_packed_shape_strides_for_copycast_kernel (
483+ sycl::queue exec_q,
484+ int src_flags,
485+ int dst_flags,
486+ py::ssize_t *device_shape_strides, // to be populated
487+ const std::vector<py::ssize_t > &common_shape,
488+ const std::vector<py::ssize_t > &src_strides,
489+ const std::vector<py::ssize_t > &dst_strides)
490+ {
491+ using shT = std::vector<py::ssize_t >;
492+ size_t nd = common_shape.size ();
493+
494+ // create host temporary for packed shape and strides managed by shared
495+     // pointer. Packed vector is concatenation of common_shape, src_strides
496+     // and dst_strides
497+ std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
498+ std::copy (common_shape.begin (), common_shape.end (),
499+ shp_host_shape_strides->begin ());
500+
501+ std::copy (src_strides.begin (), src_strides.end (),
502+ shp_host_shape_strides->begin () + nd);
503+
504+ std::copy (dst_strides.begin (), dst_strides.end (),
505+ shp_host_shape_strides->begin () + 2 * nd);
506+
507+ sycl::event copy_shape_ev = exec_q.copy <py::ssize_t >(
508+ shp_host_shape_strides->data (), device_shape_strides,
509+ shp_host_shape_strides->size ());
510+
511+ exec_q.submit ([&](sycl::handler &cgh) {
512+ cgh.depends_on (copy_shape_ev);
513+ cgh.host_task ([shp_host_shape_strides]() {
514+             // capturing the shared pointer by value keeps the host
515+             // buffer alive until the copy operation completes
516+ });
517+ });
518+
519+ return copy_shape_ev;
520+ }
521+
482522std::pair<sycl::event, sycl::event>
483523copy_usm_ndarray_into_usm_ndarray (dpctl::tensor::usm_ndarray src,
484524 dpctl::tensor::usm_ndarray dst,
@@ -677,47 +717,10 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
677717 throw std::runtime_error (" Unabled to allocate device memory" );
678718 }
679719
680- // create host temporary for packed shape and strides managed by shared
681- // pointer
682- std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
683- std::copy (simplified_shape.begin (), simplified_shape.end (),
684- shp_host_shape_strides->begin ());
685-
686- if (src_strides == nullptr ) {
687- const shT &src_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
688- ? c_contiguous_strides (nd, shape)
689- : f_contiguous_strides (nd, shape);
690- std::copy (src_contig_strides.begin (), src_contig_strides.end (),
691- shp_host_shape_strides->begin () + nd);
692- }
693- else {
694- std::copy (simplified_src_strides.begin (), simplified_src_strides.end (),
695- shp_host_shape_strides->begin () + nd);
696- }
697-
698- if (dst_strides == nullptr ) {
699- const shT &dst_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
700- ? c_contiguous_strides (nd, shape)
701- : f_contiguous_strides (nd, shape);
702- std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
703- shp_host_shape_strides->begin () + 2 * nd);
704- }
705- else {
706- std::copy (simplified_dst_strides.begin (), simplified_dst_strides.end (),
707- shp_host_shape_strides->begin () + nd);
708- }
709-
710720 sycl::event copy_shape_ev =
711- exec_q.copy <py::ssize_t >(shp_host_shape_strides->data (), shape_strides,
712- shp_host_shape_strides->size ());
713-
714- exec_q.submit ([&](sycl::handler &cgh) {
715- cgh.depends_on (copy_shape_ev);
716- cgh.host_task ([shp_host_shape_strides]() {
717- // increment shared pointer ref-count to keep it alive
718- // till copy operation completes;
719- });
720- });
721+ _populate_packed_shape_strides_for_copycast_kernel (
722+ exec_q, src_flags, dst_flags, shape_strides, simplified_shape,
723+ simplified_src_strides, simplified_dst_strides);
721724
722725 sycl::event copy_and_cast_generic_ev = copy_and_cast_fn (
723726 exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
0 commit comments