Skip to content

Commit 6b0c990

Browse files
Transpose refactoring (#574)
* replace vectors by pointers in transpose interface
1 parent 4cfe1a6 commit 6b0c990

3 files changed

Lines changed: 25 additions & 28 deletions

File tree

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,16 @@ INP_DLLEXPORT void dpnp_remainder_c(void* array1_in, void* array2_in, void* resu
548548
* @param [in] input_shape Input shape.
549549
* @param [in] result_shape Output shape.
550550
* @param [in] permute_axes Order of axis by it's id as it should be presented in output.
551+
* @param [in] ndim Number of elements in shapes and axes.
551552
* @param [out] result1 Output array.
552553
* @param [in] size Number of elements in input arrays.
553554
*/
554555
template <typename _DataType>
555556
INP_DLLEXPORT void dpnp_elemwise_transpose_c(void* array1_in,
556-
const std::vector<long>& input_shape,
557-
const std::vector<long>& result_shape,
558-
const std::vector<long>& permute_axes,
557+
const size_t* input_shape,
558+
const size_t* result_shape,
559+
const size_t* permute_axes,
560+
size_t ndim,
559561
void* result1,
560562
size_t size);
561563

dpnp/backend/kernels/dpnp_krnl_manipulation.cpp

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ class dpnp_elemwise_transpose_c_kernel;
3838

3939
template <typename _DataType>
4040
void dpnp_elemwise_transpose_c(void* array1_in,
41-
const std::vector<long>& input_shape,
42-
const std::vector<long>& result_shape,
43-
const std::vector<long>& permute_axes,
41+
const size_t* input_shape,
42+
const size_t* result_shape,
43+
const size_t* permute_axes,
44+
size_t ndim,
4445
void* result1,
4546
size_t size)
4647
{
@@ -53,24 +54,16 @@ void dpnp_elemwise_transpose_c(void* array1_in,
5354
_DataType* array1 = reinterpret_cast<_DataType*>(array1_in);
5455
_DataType* result = reinterpret_cast<_DataType*>(result1);
5556

56-
const size_t input_shape_size = input_shape.size();
57-
size_t* input_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(input_shape_size * sizeof(long)));
58-
size_t* result_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(input_shape_size * sizeof(long)));
57+
size_t* input_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(long)));
58+
get_shape_offsets_inkernel(input_shape, ndim, input_offset_shape);
5959

60-
size_t dim_prod_input = 1;
61-
size_t dim_prod_result = 1;
62-
for (long i = input_shape_size - 1; i >= 0; --i)
60+
size_t* temp_result_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(long)));
61+
get_shape_offsets_inkernel(result_shape, ndim, temp_result_offset_shape);
62+
63+
size_t* result_offset_shape = reinterpret_cast<size_t*>(dpnp_memory_alloc_c(ndim * sizeof(long)));
64+
for (size_t axis = 0; axis < ndim; ++axis)
6365
{
64-
/*
65-
for example above, offset vectors will be
66-
input_offset_shape=[12, 4, 1]
67-
result_offset_shape=[1, 2, 6]
68-
*/
69-
input_offset_shape[i] = dim_prod_input;
70-
result_offset_shape[permute_axes[i]] = dim_prod_result;
71-
72-
dim_prod_input *= input_shape[i];
73-
dim_prod_result *= result_shape[i];
66+
result_offset_shape[permute_axes[axis]] = temp_result_offset_shape[axis];
7467
}
7568

7669
cl::sycl::range<1> gws(size);
@@ -79,7 +72,7 @@ void dpnp_elemwise_transpose_c(void* array1_in,
7972

8073
size_t output_index = 0;
8174
size_t reminder = idx;
82-
for (size_t axis = 0; axis < input_shape_size; ++axis)
75+
for (size_t axis = 0; axis < ndim; ++axis)
8376
{
8477
/* reconstruct [x][y][z] from given linear idx */
8578
size_t xyz_id = reminder / input_offset_shape[axis];
@@ -100,8 +93,9 @@ void dpnp_elemwise_transpose_c(void* array1_in,
10093

10194
event.wait();
10295

103-
free(input_offset_shape, DPNP_QUEUE);
104-
free(result_offset_shape, DPNP_QUEUE);
96+
dpnp_memory_free_c(input_offset_shape);
97+
dpnp_memory_free_c(temp_result_offset_shape);
98+
dpnp_memory_free_c(result_offset_shape);
10599
}
106100

107101
void func_map_init_manipulation(func_map_t& fmap)

dpnp/dpnp_algo/dpnp_algo_manipulation.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ __all__ += [
4848

4949

5050
# C function pointer to the C library template functions
51-
ctypedef void(*fptr_custom_elemwise_transpose_1in_1out_t)(void * , dparray_shape_type & , dparray_shape_type & ,
52-
dparray_shape_type &, void * , size_t)
51+
ctypedef void(*fptr_custom_elemwise_transpose_1in_1out_t)(void * , size_t * , size_t * ,
52+
size_t * , size_t, void * , size_t)
5353

5454

5555
cpdef dparray dpnp_atleast_2d(dparray arr):
@@ -171,7 +171,8 @@ cpdef dparray dpnp_transpose(dparray array1, axes=None):
171171

172172
cdef fptr_custom_elemwise_transpose_1in_1out_t func = <fptr_custom_elemwise_transpose_1in_1out_t > kernel_data.ptr
173173
# call FPTR function
174-
func(array1.get_data(), input_shape, result_shape, permute_axes, result.get_data(), array1.size)
174+
func(array1.get_data(), < size_t * > input_shape.data(), < size_t * > result_shape.data(),
175+
< size_t * > permute_axes.data(), input_shape_size, result.get_data(), array1.size)
175176

176177
return result
177178

0 commit comments

Comments
 (0)