@@ -87,54 +87,92 @@ template <typename _DataType>
8787class dpnp_partition_c_kernel ;
8888
8989template <typename _DataType>
90- void dpnp_partition_c (const void * sort_array1_in , void * result1, const size_t kth, const size_t * shape , const size_t ndim)
90+ void dpnp_partition_c (void * array1_in, void * array2_in , void * result1, const size_t kth, const size_t * shape_ , const size_t ndim)
9191{
92+ _DataType* arr = reinterpret_cast <_DataType*>(array1_in);
93+ _DataType* arr2 = reinterpret_cast <_DataType*>(array2_in);
94+ _DataType* result = reinterpret_cast <_DataType*>(result1);
9295
93- cl::sycl::event event;
96+ if ((arr == nullptr ) || (result == nullptr ))
97+ {
98+ return ;
99+ }
94100
95- const _DataType* sort_arr = reinterpret_cast <const _DataType*>(sort_array1_in);
96- _DataType* result = reinterpret_cast <_DataType*>(result1);
101+ if (ndim < 1 )
102+ {
103+ return ;
104+ }
97105
98- size_t size_ = 1 ;
99- for (size_t i = 0 ; i < ndim - 1 ; ++i)
106+ size_t size = 1 ;
107+ for (size_t i = 0 ; i < ndim; ++i)
100108 {
101- size_ *= shape [i];
109+ size *= shape_ [i];
102110 }
103111
112+ size_t size_ = size/shape_[ndim-1 ];
113+
104114 if (size_ == 0 )
105115 {
106116 return ;
107117 }
108118
119+ auto arr_to_result_event = DPNP_QUEUE.memcpy (result, arr, size * sizeof (_DataType));
120+ arr_to_result_event.wait ();
121+
122+ for (size_t i = 0 ; i < size_; ++i)
123+ {
124+ size_t ind_begin = i * shape_[ndim-1 ];
125+ size_t ind_end = (i + 1 ) * shape_[ndim-1 ] - 1 ;
126+
127+ _DataType matrix[shape_[ndim-1 ]];
128+ for (size_t j = ind_begin; j < ind_end + 1 ; ++j)
129+ {
130+ size_t ind = j - ind_begin;
131+ matrix[ind] = arr2[j];
132+ }
133+ std::partial_sort (matrix, matrix + shape_[ndim-1 ], matrix + shape_[ndim-1 ]);
134+ for (size_t j = ind_begin; j < ind_end + 1 ; ++j)
135+ {
136+ size_t ind = j - ind_begin;
137+ arr2[j] = matrix[ind];
138+ }
139+ }
140+
141+ size_t * shape = reinterpret_cast <size_t *>(dpnp_memory_alloc_c (ndim * sizeof (size_t )));
142+ auto memcpy_event = DPNP_QUEUE.memcpy (shape, shape_, ndim * sizeof (size_t ));
143+
144+ memcpy_event.wait ();
145+
109146 cl::sycl::range<2 > gws (size_, kth+1 );
110147 auto kernel_parallel_for_func = [=](cl::sycl::id<2 > global_id) {
111148 size_t j = global_id[0 ];
112149 size_t k = global_id[1 ];
113150
114- _DataType val = sort_arr [j * shape[ndim - 1 ] + k];
151+ _DataType val = arr2 [j * shape[ndim - 1 ] + k];
115152
116- size_t ind = j * shape[ndim - 1 ] + k;
117153 for (size_t i = 0 ; i < shape[ndim - 1 ]; ++i)
118154 {
119155 if (result[j * shape[ndim - 1 ] + i] == val)
120156 {
121- ind = j * shape[ndim - 1 ] + i;
122- break ;
157+ _DataType change_val1 = result[j * shape[ndim - 1 ] + i];
158+ _DataType change_val2 = result[j * shape[ndim - 1 ] + k];
159+ result[j * shape[ndim - 1 ] + k] = change_val1;
160+ result[j * shape[ndim - 1 ] + i] = change_val2;
123161 }
124162 }
125163
126- _DataType change_val = result[j * shape[ndim - 1 ] + k];
127- result[j * shape[ndim - 1 ] + k] = val;
128- result[ind] = change_val;
129164 };
130165
131166 auto kernel_func = [&](cl::sycl::handler& cgh) {
167+ cgh.depends_on ({memcpy_event});
132168 cgh.parallel_for <class dpnp_partition_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
133169 };
134170
135- event = DPNP_QUEUE.submit (kernel_func);
171+ auto event = DPNP_QUEUE.submit (kernel_func);
136172
137173 event.wait ();
174+
175+ dpnp_memory_free_c (shape);
138176}
139177
140178template <typename _DataType>
0 commit comments