4242
4343namespace py = pybind11;
4444
45- static dpctl::tensor::detail::usm_ndarray_types array_types;
46-
4745namespace
4846{
4947
@@ -301,6 +299,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
301299 int src_typenum = src.get_typenum ();
302300 int dst_typenum = dst.get_typenum ();
303301
302+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
304303 int src_type_id = array_types.typenum_to_lookup_id (src_typenum);
305304 int dst_type_id = array_types.typenum_to_lookup_id (dst_typenum);
306305
@@ -322,15 +321,16 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
322321 throw py::value_error (" Arrays index overlapping segments of memory" );
323322 }
324323
325- int src_flags = src.get_flags ();
326- int dst_flags = dst.get_flags ();
324+ bool is_src_c_contig = src.is_c_contiguous ();
325+ bool is_src_f_contig = src.is_f_contiguous ();
326+
327+ bool is_dst_c_contig = dst.is_c_contiguous ();
328+ bool is_dst_f_contig = dst.is_f_contiguous ();
327329
328330 // check for applicability of special cases:
329331 // (same type && (both C-contiguous || both F-contiguous)
330- bool both_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) &&
331- (dst_flags & USM_ARRAY_C_CONTIGUOUS));
332- bool both_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) &&
333- (dst_flags & USM_ARRAY_F_CONTIGUOUS));
332+ bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
333+ bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
334334 if (both_c_contig || both_f_contig) {
335335 if (src_type_id == dst_type_id) {
336336
@@ -360,12 +360,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
360360 int nd = src_nd;
361361 const py::ssize_t *shape = src_shape;
362362
363- bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
364- bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
365-
366- bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
367- bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
368-
369363 constexpr py::ssize_t src_itemsize = 1 ; // in elements
370364 constexpr py::ssize_t dst_itemsize = 1 ; // in elements
371365
@@ -550,6 +544,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
550544 const py::ssize_t *src_shape = src.get_shape_raw ();
551545 const py::ssize_t *dst_shape = dst.get_shape_raw ();
552546
547+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
553548 int type_id = array_types.typenum_to_lookup_id (src_typenum);
554549
555550 auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
@@ -576,14 +571,13 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
576571
577572 const py::ssize_t *src_strides = src.get_strides_raw ();
578573 if (src_strides == nullptr ) {
579- int src_flags = src.get_flags ();
580- if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
574+ if (src.is_c_contiguous ()) {
581575 const auto &src_contig_strides =
582576 c_contiguous_strides (src_nd, src_shape);
583577 std::copy (src_contig_strides.begin (), src_contig_strides.end (),
584578 packed_host_shapes_strides_shp->begin () + src_nd);
585579 }
586- else if (src_flags & USM_ARRAY_F_CONTIGUOUS ) {
580+ else if (src. is_f_contiguous () ) {
587581 const auto &src_contig_strides =
588582 f_contiguous_strides (src_nd, src_shape);
589583 std::copy (src_contig_strides.begin (), src_contig_strides.end (),
@@ -602,15 +596,14 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
602596
603597 const py::ssize_t *dst_strides = dst.get_strides_raw ();
604598 if (dst_strides == nullptr ) {
605- int dst_flags = dst.get_flags ();
606- if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
599+ if (dst.is_c_contiguous ()) {
607600 const auto &dst_contig_strides =
608601 c_contiguous_strides (dst_nd, dst_shape);
609602 std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
610603 packed_host_shapes_strides_shp->begin () + 2 * src_nd +
611604 dst_nd);
612605 }
613- else if (dst_flags & USM_ARRAY_F_CONTIGUOUS ) {
606+ else if (dst. is_f_contiguous () ) {
614607 const auto &dst_contig_strides =
615608 f_contiguous_strides (dst_nd, dst_shape);
616609 std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
@@ -736,6 +729,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
736729 py::detail::array_descriptor_proxy (npy_src.dtype ().ptr ())->type_num ;
737730 int dst_typenum = dst.get_typenum ();
738731
732+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
739733 int src_type_id = array_types.typenum_to_lookup_id (src_typenum);
740734 int dst_type_id = array_types.typenum_to_lookup_id (dst_typenum);
741735
@@ -744,14 +738,13 @@ void copy_numpy_ndarray_into_usm_ndarray(
744738 char *dst_data = dst.get_data ();
745739
746740 int src_flags = npy_src.flags ();
747- int dst_flags = dst.get_flags ();
748741
749742 // check for applicability of special cases:
750743 // (same type && (both C-contiguous || both F-contiguous)
751- bool both_c_contig = ((src_flags & py::array::c_style) &&
752- (dst_flags & USM_ARRAY_C_CONTIGUOUS ));
753- bool both_f_contig = ((src_flags & py::array::f_style) &&
754- (dst_flags & USM_ARRAY_F_CONTIGUOUS ));
744+ bool both_c_contig =
745+ ((src_flags & py::array::c_style) && dst. is_c_contiguous ( ));
746+ bool both_f_contig =
747+ ((src_flags & py::array::f_style) && dst. is_f_contiguous ( ));
755748 if (both_c_contig || both_f_contig) {
756749 if (src_type_id == dst_type_id) {
757750 int src_elem_size = npy_src.itemsize ();
@@ -791,8 +784,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
791784 bool is_src_c_contig = ((src_flags & py::array::c_style) != 0 );
792785 bool is_src_f_contig = ((src_flags & py::array::f_style) != 0 );
793786
794- bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
795- bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
787+ bool is_dst_c_contig = dst. is_c_contiguous ( );
788+ bool is_dst_f_contig = dst. is_f_contiguous ( );
796789
797790 // all args except itemsizes and is_?_contig bools can be modified by
798791 // reference
@@ -906,18 +899,18 @@ usm_ndarray_linear_sequence_step(py::object start,
906899 " usm_ndarray_linspace: Expecting 1D array to populate" );
907900 }
908901
909- int flags = dst.get_flags ();
910- if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
902+ if (!dst.is_c_contiguous ()) {
911903 throw py::value_error (
912904 " usm_ndarray_linspace: Non-contiguous arrays are not supported" );
913905 }
914906
915907 sycl::queue dst_q = dst.get_queue ();
916- if (dst_q != exec_q && dst_q. get_context () != exec_q. get_context ( )) {
908+ if (! dpctl::utils::queues_are_compatible ( exec_q, { dst_q} )) {
917909 throw py::value_error (
918- " Execution queue context is not the same as allocation context " );
910+ " Execution queue is not compatible with the allocation queue " );
919911 }
920912
913+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
921914 int dst_typenum = dst.get_typenum ();
922915 int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
923916
@@ -955,18 +948,18 @@ usm_ndarray_linear_sequence_affine(py::object start,
955948 " usm_ndarray_linspace: Expecting 1D array to populate" );
956949 }
957950
958- int flags = dst.get_flags ();
959- if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
951+ if (!dst.is_c_contiguous ()) {
960952 throw py::value_error (
961953 " usm_ndarray_linspace: Non-contiguous arrays are not supported" );
962954 }
963955
964956 sycl::queue dst_q = dst.get_queue ();
965- if (dst_q != exec_q && dst_q. get_context () != exec_q. get_context ( )) {
957+ if (! dpctl::utils::queues_are_compatible ( exec_q, { dst_q} )) {
966958 throw py::value_error (
967959 " Execution queue context is not the same as allocation context" );
968960 }
969961
962+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
970963 int dst_typenum = dst.get_typenum ();
971964 int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
972965
@@ -1010,23 +1003,20 @@ usm_ndarray_full(py::object py_value,
10101003 return std::make_pair (sycl::event (), sycl::event ());
10111004 }
10121005
1013- int dst_flags = dst.get_flags ();
1014-
10151006 sycl::queue dst_q = dst.get_queue ();
1016- if (dst_q != exec_q && dst_q. get_context () != exec_q. get_context ( )) {
1007+ if (! dpctl::utils::queues_are_compatible ( exec_q, { dst_q} )) {
10171008 throw py::value_error (
1018- " Execution queue context is not the same as allocation context " );
1009+ " Execution queue is not compatible with the allocation queue " );
10191010 }
10201011
1012+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
10211013 int dst_typenum = dst.get_typenum ();
10221014 int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
10231015
10241016 char *dst_data = dst.get_data ();
10251017 sycl::event full_event;
10261018
1027- if (dst_nelems == 1 || (dst_flags & USM_ARRAY_C_CONTIGUOUS) ||
1028- (dst_flags & USM_ARRAY_F_CONTIGUOUS))
1029- {
1019+ if (dst_nelems == 1 || dst.is_c_contiguous () || dst.is_f_contiguous ()) {
10301020 auto fn = full_contig_dispatch_vector[dst_typeid];
10311021
10321022 sycl::event full_contig_event =
@@ -1068,6 +1058,7 @@ eye(py::ssize_t k,
10681058 " allocation queue" );
10691059 }
10701060
1061+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
10711062 int dst_typenum = dst.get_typenum ();
10721063 int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
10731064
@@ -1079,8 +1070,8 @@ eye(py::ssize_t k,
10791070 return std::make_pair (sycl::event{}, sycl::event{});
10801071 }
10811072
1082- bool is_dst_c_contig = (( dst.get_flags () & USM_ARRAY_C_CONTIGUOUS) != 0 );
1083- bool is_dst_f_contig = (( dst.get_flags () & USM_ARRAY_F_CONTIGUOUS) != 0 );
1073+ bool is_dst_c_contig = dst.is_c_contiguous ( );
1074+ bool is_dst_f_contig = dst.is_f_contiguous ( );
10841075 if (!is_dst_c_contig && !is_dst_f_contig) {
10851076 throw py::value_error (" USM array is not contiguous" );
10861077 }
@@ -1182,6 +1173,8 @@ tri(sycl::queue &exec_q,
11821173 throw py::value_error (" Arrays index overlapping segments of memory" );
11831174 }
11841175
1176+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
1177+
11851178 int src_typenum = src.get_typenum ();
11861179 int dst_typenum = dst.get_typenum ();
11871180 int src_typeid = array_types.typenum_to_lookup_id (src_typenum);
@@ -1203,9 +1196,8 @@ tri(sycl::queue &exec_q,
12031196 using shT = std::vector<py::ssize_t >;
12041197 shT src_strides (src_nd);
12051198
1206- int src_flags = src.get_flags ();
1207- bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
1208- bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
1199+ bool is_src_c_contig = src.is_c_contiguous ();
1200+ bool is_src_f_contig = src.is_f_contiguous ();
12091201
12101202 const py::ssize_t *src_strides_raw = src.get_strides_raw ();
12111203 if (src_strides_raw == nullptr ) {
@@ -1227,9 +1219,8 @@ tri(sycl::queue &exec_q,
12271219
12281220 shT dst_strides (src_nd);
12291221
1230- int dst_flags = dst.get_flags ();
1231- bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
1232- bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
1222+ bool is_dst_c_contig = dst.is_c_contiguous ();
1223+ bool is_dst_f_contig = dst.is_f_contiguous ();
12331224
12341225 const py::ssize_t *dst_strides_raw = dst.get_strides_raw ();
12351226 if (dst_strides_raw == nullptr ) {
@@ -1457,9 +1448,6 @@ PYBIND11_MODULE(_tensor_impl, m)
14571448 init_copy_for_reshape_dispatch_vector ();
14581449 import_dpctl ();
14591450
1460- // populate types constants for type dispatching functions
1461- array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
1462-
14631451 m.def (
14641452 " _contract_iter" , &contract_iter,
14651453 " Simplifies iteration of array of given shape & stride. Returns "
0 commit comments