@@ -38,9 +38,10 @@ class dpnp_elemwise_transpose_c_kernel;
3838
3939template <typename _DataType>
4040void dpnp_elemwise_transpose_c (void * array1_in,
41- const std::vector<long >& input_shape,
42- const std::vector<long >& result_shape,
43- const std::vector<long >& permute_axes,
41+ const size_t * input_shape,
42+ const size_t * result_shape,
43+ const size_t * permute_axes,
44+ size_t ndim,
4445 void * result1,
4546 size_t size)
4647{
@@ -53,24 +54,16 @@ void dpnp_elemwise_transpose_c(void* array1_in,
5354 _DataType* array1 = reinterpret_cast <_DataType*>(array1_in);
5455 _DataType* result = reinterpret_cast <_DataType*>(result1);
5556
56- const size_t input_shape_size = input_shape.size ();
57- size_t * input_offset_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (input_shape_size * sizeof (long )));
58- size_t * result_offset_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (input_shape_size * sizeof (long )));
57+ size_t * input_offset_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (long )));
58+ get_shape_offsets_inkernel (input_shape, ndim, input_offset_shape);
5959
60- size_t dim_prod_input = 1 ;
61- size_t dim_prod_result = 1 ;
62- for (long i = input_shape_size - 1 ; i >= 0 ; --i)
60+ size_t * temp_result_offset_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (long )));
61+ get_shape_offsets_inkernel (result_shape, ndim, temp_result_offset_shape);
62+
63+ size_t * result_offset_shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (long )));
64+ for (size_t axis = 0 ; axis < ndim; ++axis)
6365 {
64- /*
65- for example above, offset vectors will be
66- input_offset_shape=[12, 4, 1]
67- result_offset_shape=[1, 2, 6]
68- */
69- input_offset_shape[i] = dim_prod_input;
70- result_offset_shape[permute_axes[i]] = dim_prod_result;
71-
72- dim_prod_input *= input_shape[i];
73- dim_prod_result *= result_shape[i];
66+ result_offset_shape[permute_axes[axis]] = temp_result_offset_shape[axis];
7467 }
7568
7669 cl::sycl::range<1 > gws (size);
@@ -79,7 +72,7 @@ void dpnp_elemwise_transpose_c(void* array1_in,
7972
8073 size_t output_index = 0 ;
8174 size_t reminder = idx;
82- for (size_t axis = 0 ; axis < input_shape_size ; ++axis)
75+ for (size_t axis = 0 ; axis < ndim ; ++axis)
8376 {
8477 /* reconstruct [x][y][z] from given linear idx */
8578 size_t xyz_id = reminder / input_offset_shape[axis];
@@ -100,8 +93,9 @@ void dpnp_elemwise_transpose_c(void* array1_in,
10093
10194 event.wait ();
10295
103- free (input_offset_shape, DPNP_QUEUE);
104- free (result_offset_shape, DPNP_QUEUE);
96+ dpnp_memory_free_c (input_offset_shape);
97+ dpnp_memory_free_c (temp_result_offset_shape);
98+ dpnp_memory_free_c (result_offset_shape);
10599}
106100
107101void func_map_init_manipulation (func_map_t & fmap)
0 commit comments