2323// ===----------------------------------------------------------------------===//
2424
2525#include < CL/sycl.hpp>
26+ #include < algorithm>
2627#include < complex>
2728#include < cstdint>
2829#include < pybind11/complex.h>
@@ -663,12 +664,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
663664 }
664665 }
665666
666- std::shared_ptr<shT> shp_shape = std::make_shared<shT>(simplified_shape);
667- std::shared_ptr<shT> shp_src_strides =
668- std::make_shared<shT>(simplified_src_strides);
669- std::shared_ptr<shT> shp_dst_strides =
670- std::make_shared<shT>(simplified_dst_strides);
671-
672667 // Generic implementation
673668 auto copy_and_cast_fn =
674669 copy_and_cast_generic_dispatch_table[dst_type_id][src_type_id];
@@ -682,77 +677,50 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
682677 throw std::runtime_error (" Unabled to allocate device memory" );
683678 }
684679
685- sycl::event copy_shape_ev =
686- exec_q.copy <py::ssize_t >(shp_shape->data (), shape_strides, nd);
680+ // create host temporary for packed shape and strides managed by shared
681+ // pointer
682+ std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
683+ std::copy (simplified_shape.begin (), simplified_shape.end (),
684+ shp_host_shape_strides->begin ());
687685
688- exec_q.submit ([&](sycl::handler &cgh) {
689- cgh.depends_on (copy_shape_ev);
690- cgh.host_task ([shp_shape]() {
691- // increment shared pointer ref-count to keep it alive
692- // till copy operation completes;
693- });
694- });
695-
696- sycl::event copy_src_strides_ev;
697686 if (src_strides == nullptr ) {
698- std::shared_ptr<shT> shp_contig_src_strides =
699- std::make_shared<shT>((src_flags & USM_ARRAY_C_CONTIGUOUS)
700- ? c_contiguous_strides (nd, shape)
701- : f_contiguous_strides (nd, shape));
702- copy_src_strides_ev = exec_q.copy <py::ssize_t >(
703- shp_contig_src_strides->data (), shape_strides + nd, nd);
704- exec_q.submit ([&](sycl::handler &cgh) {
705- cgh.depends_on (copy_src_strides_ev);
706- cgh.host_task ([shp_contig_src_strides]() {
707- // increment shared pointer ref-count to keep it alive
708- // till copy operation completes;
709- });
710- });
687+ const shT &src_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
688+ ? c_contiguous_strides (nd, shape)
689+ : f_contiguous_strides (nd, shape);
690+ std::copy (src_contig_strides.begin (), src_contig_strides.end (),
691+ shp_host_shape_strides->begin () + nd);
711692 }
712693 else {
713- copy_src_strides_ev = exec_q.copy <py::ssize_t >(shp_src_strides->data (),
714- shape_strides + nd, nd);
715- exec_q.submit ([&](sycl::handler &cgh) {
716- cgh.depends_on (copy_src_strides_ev);
717- cgh.host_task ([shp_src_strides]() {
718- // increment shared pointer ref-count to keep it alive
719- // till copy operation completes;
720- });
721- });
694+ std::copy (simplified_src_strides.begin (), simplified_src_strides.end (),
695+ shp_host_shape_strides->begin () + nd);
722696 }
723697
724- sycl::event copy_dst_strides_ev;
725698 if (dst_strides == nullptr ) {
726- std::shared_ptr<shT> shp_contig_dst_strides =
727- std::make_shared<shT>((dst_flags & USM_ARRAY_C_CONTIGUOUS)
728- ? c_contiguous_strides (nd, shape)
729- : f_contiguous_strides (nd, shape));
730- copy_dst_strides_ev = exec_q.copy <py::ssize_t >(
731- shp_contig_dst_strides->data (), shape_strides + 2 * nd, nd);
732- exec_q.submit ([&](sycl::handler &cgh) {
733- cgh.depends_on (copy_dst_strides_ev);
734- cgh.host_task ([shp_contig_dst_strides]() {
735- // increment shared pointer ref-count to keep it alive
736- // till copy operation completes;
737- });
738- });
699+ const shT &dst_contig_strides = (dst_flags & USM_ARRAY_C_CONTIGUOUS) // pick C- vs F-contiguous layout from the destination's own flags, not the source's
700+ ? c_contiguous_strides (nd, shape)
701+ : f_contiguous_strides (nd, shape);
702+ std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
703+ shp_host_shape_strides->begin () + 2 * nd); // destination strides go at offset 2 * nd of the packed 3 * nd host buffer
739704 }
740705 else {
741- copy_dst_strides_ev = exec_q.copy <py::ssize_t >(
742- shp_dst_strides->data (), shape_strides + 2 * nd, nd);
743- exec_q.submit ([&](sycl::handler &cgh) {
744- cgh.depends_on (copy_dst_strides_ev);
745- cgh.host_task ([shp_dst_strides]() {
746- // increment shared pointer ref-count to keep it alive
747- // till copy operation completes;
748- });
749- });
706+ std::copy (simplified_dst_strides.begin (), simplified_dst_strides.end (),
707+ shp_host_shape_strides->begin () + 2 * nd); // destination strides go at offset 2 * nd; offset nd holds the source strides and must not be overwritten
750708 }
751709
710+ sycl::event copy_shape_ev = exec_q.copy <py::ssize_t >(
711+ shp_host_shape_strides->data (), shape_strides, 3 * nd);
712+
713+ exec_q.submit ([&](sycl::handler &cgh) {
714+ cgh.depends_on (copy_shape_ev);
715+ cgh.host_task ([shp_host_shape_strides]() {
716+ // increment shared pointer ref-count to keep it alive
717+ // till copy operation completes;
718+ });
719+ });
720+
752721 sycl::event copy_and_cast_generic_ev = copy_and_cast_fn (
753722 exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
754- dst_offset, depends,
755- {copy_shape_ev, copy_src_strides_ev, copy_dst_strides_ev});
723+ dst_offset, depends, {copy_shape_ev});
756724
757725 // async free of shape_strides temporary
758726 auto ctx = exec_q.get_context ();
0 commit comments