@@ -52,15 +52,12 @@ namespace py_internal
5252namespace _ns = dpctl::tensor::detail;
5353
5454using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t ;
55- using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_2d_fn_ptr_t ;
5655using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t ;
5756
5857static copy_and_cast_generic_fn_ptr_t
5958 copy_and_cast_generic_dispatch_table[_ns::num_types][_ns::num_types];
6059static copy_and_cast_1d_fn_ptr_t
6160 copy_and_cast_1d_dispatch_table[_ns::num_types][_ns::num_types];
62- static copy_and_cast_2d_fn_ptr_t
63- copy_and_cast_2d_dispatch_table[_ns::num_types][_ns::num_types];
6461
6562namespace py = pybind11;
6663
@@ -187,7 +184,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
187184 simplified_shape, simplified_src_strides, simplified_dst_strides,
188185 src_offset, dst_offset);
189186
190- if (nd < 3 ) {
187+ if (nd < 2 ) {
191188 if (nd == 1 ) {
192189 std::array<py::ssize_t , 1 > shape_arr = {shape[0 ]};
193190 // strides may be null
@@ -205,23 +202,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
205202 keep_args_alive (exec_q, {src, dst}, {copy_and_cast_1d_event}),
206203 copy_and_cast_1d_event);
207204 }
208- else if (nd == 2 ) {
209- std::array<py::ssize_t , 2 > shape_arr = {shape[0 ], shape[1 ]};
210- std::array<py::ssize_t , 2 > src_strides_arr = {src_strides[0 ],
211- src_strides[1 ]};
212- std::array<py::ssize_t , 2 > dst_strides_arr = {dst_strides[0 ],
213- dst_strides[1 ]};
214-
215- auto fn = copy_and_cast_2d_dispatch_table[dst_type_id][src_type_id];
216-
217- sycl::event copy_and_cast_2d_event = fn (
218- exec_q, src_nelems, shape_arr, src_strides_arr, dst_strides_arr,
219- src_data, src_offset, dst_data, dst_offset, depends);
220-
221- return std::make_pair (
222- keep_args_alive (exec_q, {src, dst}, {copy_and_cast_2d_event}),
223- copy_and_cast_2d_event);
224- }
225205 else if (nd == 0 ) { // case of a scalar
226206 assert (src_nelems == 1 );
227207 std::array<py::ssize_t , 1 > shape_arr = {1 };
@@ -290,12 +270,6 @@ void init_copy_and_cast_usm_to_usm_dispatch_tables(void)
290270 num_types>
291271 dtb_1d;
292272 dtb_1d.populate_dispatch_table (copy_and_cast_1d_dispatch_table);
293-
294- using dpctl::tensor::kernels::copy_and_cast::CopyAndCast2DFactory;
295- DispatchTableBuilder<copy_and_cast_2d_fn_ptr_t , CopyAndCast2DFactory,
296- num_types>
297- dtb_2d;
298- dtb_2d.populate_dispatch_table (copy_and_cast_2d_dispatch_table);
299273}
300274
301275} // namespace py_internal
0 commit comments