@@ -932,99 +932,96 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
932932
933933 auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
934934
935+ // packed_shapes_strides = [src_shape, src_strides, dst_shape, dst_strides]
935936 py::ssize_t *packed_shapes_strides =
936937 sycl::malloc_device<py::ssize_t >(2 * (src_nd + dst_nd), exec_q);
937938
938939 if (packed_shapes_strides == nullptr ) {
939940 throw std::runtime_error (" Unabled to allocate device memory" );
940941 }
941942
942- sycl::event src_shape_copy_ev =
943- exec_q.copy <py::ssize_t >(src_shape, packed_shapes_strides, src_nd);
944- sycl::event dst_shape_copy_ev = exec_q.copy <py::ssize_t >(
945- dst_shape, packed_shapes_strides + 2 * src_nd, dst_nd);
943+ using shT = std::vector<py::ssize_t >;
944+ std::shared_ptr<shT> packed_host_shapes_strides_shp =
945+ std::make_shared<shT>(2 * (src_nd + dst_nd));
946+
947+ std::copy (src_shape, src_shape + src_nd,
948+ packed_host_shapes_strides_shp->begin ());
949+ std::copy (dst_shape, dst_shape + dst_nd,
950+ packed_host_shapes_strides_shp->begin () + 2 * src_nd);
946951
947952 const py::ssize_t *src_strides = src.get_strides_raw ();
948- sycl::event src_strides_copy_ev;
949953 if (src_strides == nullptr ) {
950- using shT = std::vector<py::ssize_t >;
951954 int src_flags = src.get_flags ();
952- std::shared_ptr<shT> contig_src_strides_shp;
953955 if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
954- contig_src_strides_shp =
955- std::make_shared<shT>(c_contiguous_strides (src_nd, src_shape));
956+ const shT &src_contig_strides =
957+ c_contiguous_strides (src_nd, src_shape);
958+ std::copy (src_contig_strides.begin (), src_contig_strides.end (),
959+ packed_host_shapes_strides_shp->begin () + src_nd);
956960 }
957961 else if (src_flags & USM_ARRAY_F_CONTIGUOUS) {
958- contig_src_strides_shp =
959- std::make_shared<shT>(f_contiguous_strides (src_nd, src_shape));
962+ const shT &src_contig_strides =
963+ c_contiguous_strides (src_nd, src_shape);
964+ std::copy (src_contig_strides.begin (), src_contig_strides.end (),
965+ packed_host_shapes_strides_shp->begin () + src_nd);
960966 }
961967 else {
962- sycl::event::wait ({src_shape_copy_ev, dst_shape_copy_ev});
963968 sycl::free (packed_shapes_strides, exec_q);
964969 throw std::runtime_error (
965970 " Invalid src array encountered: in copy_for_reshape function" );
966971 }
967- src_strides_copy_ev =
968- exec_q.copy <py::ssize_t >(contig_src_strides_shp->data (),
969- packed_shapes_strides + src_nd, src_nd);
970- exec_q.submit ([&](sycl::handler &cgh) {
971- cgh.depends_on (src_strides_copy_ev);
972- cgh.host_task ([contig_src_strides_shp]() {
973- // Capturing shared pointer ensure it is freed after its data
974- // are copied into packed USM vector
975- });
976- });
977972 }
978973 else {
979- src_strides_copy_ev = exec_q. copy <py:: ssize_t >(
980- src_strides, packed_shapes_strides + src_nd, src_nd);
974+ std::copy (src_strides, src_strides + src_nd,
975+ packed_host_shapes_strides_shp-> begin () + src_nd);
981976 }
982977
983978 const py::ssize_t *dst_strides = dst.get_strides_raw ();
984- sycl::event dst_strides_copy_ev;
985979 if (dst_strides == nullptr ) {
986- using shT = std::vector<py::ssize_t >;
987980 int dst_flags = dst.get_flags ();
988- std::shared_ptr<shT> contig_dst_strides_shp;
989981 if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
990- contig_dst_strides_shp =
991- std::make_shared<shT>(c_contiguous_strides (dst_nd, dst_shape));
982+ const shT &dst_contig_strides =
983+ c_contiguous_strides (dst_nd, dst_shape);
984+ std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
985+ packed_host_shapes_strides_shp->begin () + 2 * src_nd +
986+ dst_nd);
992987 }
993988 else if (dst_flags & USM_ARRAY_F_CONTIGUOUS) {
994- contig_dst_strides_shp =
995- std::make_shared<shT>(f_contiguous_strides (dst_nd, dst_shape));
989+ const shT &dst_contig_strides =
990+ f_contiguous_strides (dst_nd, dst_shape);
991+ std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
992+ packed_host_shapes_strides_shp->begin () + 2 * src_nd +
993+ dst_nd);
996994 }
997995 else {
998- sycl::event::wait (
999- {src_shape_copy_ev, dst_shape_copy_ev, src_strides_copy_ev});
1000996 sycl::free (packed_shapes_strides, exec_q);
1001997 throw std::runtime_error (
1002998 " Invalid dst array encountered: in copy_for_reshape function" );
1003999 }
1004- dst_strides_copy_ev = exec_q.copy <py::ssize_t >(
1005- contig_dst_strides_shp->data (),
1006- packed_shapes_strides + 2 * src_nd + dst_nd, dst_nd);
1007- exec_q.submit ([&](sycl::handler &cgh) {
1008- cgh.depends_on (dst_strides_copy_ev);
1009- cgh.host_task ([contig_dst_strides_shp]() {
1010- // Capturing shared pointer ensure it is freed after its data
1011- // are copied into packed USM vector
1012- });
1013- });
10141000 }
10151001 else {
1016- dst_strides_copy_ev = exec_q.copy <py::ssize_t >(
1017- dst_strides, packed_shapes_strides + 2 * src_nd + dst_nd, dst_nd);
1002+ std::copy (dst_strides, dst_strides + dst_nd,
1003+ packed_host_shapes_strides_shp->begin () + 2 * src_nd +
1004+ dst_nd);
10181005 }
10191006
1007+ // copy packed shapes and strides from host to device
1008+ sycl::event packed_shape_strides_copy_ev = exec_q.copy <py::ssize_t >(
1009+ packed_host_shapes_strides_shp->data (), packed_shapes_strides,
1010+ packed_host_shapes_strides_shp->size ());
1011+ exec_q.submit ([&](sycl::handler &cgh) {
1012+ cgh.depends_on (packed_shape_strides_copy_ev);
1013+ cgh.host_task ([packed_host_shapes_strides_shp] {
1014+ // Capturing shared pointer ensures that the underlying vector is
1015+ // not destroyed until after its data are copied into packed USM
1016+ // vector
1017+ });
1018+ });
1019+
10201020 char *src_data = src.get_data ();
10211021 char *dst_data = dst.get_data ();
10221022
1023- std::vector<sycl::event> all_deps (depends.size () + 4 );
1024- all_deps.push_back (src_shape_copy_ev);
1025- all_deps.push_back (dst_shape_copy_ev);
1026- all_deps.push_back (src_strides_copy_ev);
1027- all_deps.push_back (dst_strides_copy_ev);
1023+ std::vector<sycl::event> all_deps (depends.size () + 1 );
1024+ all_deps.push_back (packed_shape_strides_copy_ev);
10281025 all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
10291026
10301027 sycl::event copy_for_reshape_event =
0 commit comments