|
23 | 23 | // THE POSSIBILITY OF SUCH DAMAGE. |
24 | 24 | //***************************************************************************** |
25 | 25 |
|
26 | | -#include <iostream> |
27 | | -#include <list> |
28 | | - |
29 | 26 | #include <dpnp_iface.hpp> |
30 | 27 | #include "dpnp_fptr.hpp" |
31 | | -#include "dpnp_utils.hpp" |
32 | 28 | #include "queue_sycl.hpp" |
33 | 29 |
|
34 | | -template <typename _DataType> |
| 30 | +template <typename _DataType, typename _IndecesType> |
35 | 31 | class dpnp_take_c_kernel; |
36 | 32 |
|
37 | | -template <typename _DataType> |
| 33 | +template <typename _DataType, typename _IndecesType> |
38 | 34 | void dpnp_take_c(void* array1_in, void* indices1, void* result1, size_t size) |
39 | 35 | { |
40 | 36 | _DataType* array_1 = reinterpret_cast<_DataType*>(array1_in); |
41 | 37 | _DataType* result = reinterpret_cast<_DataType*>(result1); |
42 | | - size_t* indices = reinterpret_cast<size_t*>(indices1); |
| 38 | + _IndecesType* indices = reinterpret_cast<_IndecesType*>(indices1); |
| 39 | + |
| 40 | + cl::sycl::range<1> gws(size); |
| 41 | + auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { |
| 42 | + const size_t idx = global_id[0]; |
| 43 | + result[idx] = array_1[indices[idx]]; |
| 44 | + }; |
| 45 | + |
| 46 | + auto kernel_func = [&](cl::sycl::handler& cgh) { |
| 47 | + cgh.parallel_for<class dpnp_take_c_kernel<_DataType, _IndecesType>>(gws, kernel_parallel_for_func); |
| 48 | + }; |
| 49 | + |
| 50 | + cl::sycl::event event = DPNP_QUEUE.submit(kernel_func); |
43 | 51 |
|
44 | | - for (size_t i = 0; i < size; i++) |
45 | | - { |
46 | | - size_t ind = indices[i]; |
47 | | - result[i] = array_1[ind]; |
48 | | - } |
| 52 | + event.wait(); |
49 | 53 |
|
50 | 54 | return; |
51 | 55 | } |
52 | 56 |
|
53 | 57 | void func_map_init_indexing_func(func_map_t& fmap) |
54 | 58 | { |
55 | | - fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_take_c<int>}; |
56 | | - fmap[DPNPFuncName::DPNP_FN_TAKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_take_c<long>}; |
57 | | - fmap[DPNPFuncName::DPNP_FN_TAKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_take_c<float>}; |
58 | | - fmap[DPNPFuncName::DPNP_FN_TAKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_take_c<double>}; |
| 59 | + fmap[DPNPFuncName::DPNP_FN_TAKE][eft_BOOL][eft_BOOL] = {eft_BOOL, (void*)dpnp_take_c<bool, long>}; |
| 60 | + fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_take_c<int, long>}; |
| 61 | + fmap[DPNPFuncName::DPNP_FN_TAKE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_take_c<long, long>}; |
| 62 | + fmap[DPNPFuncName::DPNP_FN_TAKE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_take_c<float, long>}; |
| 63 | + fmap[DPNPFuncName::DPNP_FN_TAKE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_take_c<double, long>}; |
| 64 | + fmap[DPNPFuncName::DPNP_FN_TAKE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_take_c<std::complex<double>, long>}; |
59 | 65 |
|
60 | 66 | return; |
61 | 67 | } |
0 commit comments