|
34 | 34 | namespace mkl_blas = oneapi::mkl::blas::row_major; |
35 | 35 | namespace mkl_lapack = oneapi::mkl::lapack; |
36 | 36 |
|
37 | | -template <typename _DataType> |
38 | | -class dpnp_cholesky_c_kernel; |
39 | 37 |
|
40 | 38 | template <typename _DataType> |
41 | | -void dpnp_cholesky_c(void* array1_in, void* result1, size_t* shape) |
| 39 | +void dpnp_cholesky_c(void* array1_in, void* result1, const size_t size, const size_t data_size) |
42 | 40 | { |
43 | | - _DataType* array_1 = reinterpret_cast<_DataType*>(array1_in); |
44 | | - _DataType* l_result = reinterpret_cast<_DataType*>(result1); |
| 41 | + cl::sycl::event event; |
45 | 42 |
|
46 | | - size_t n = shape[0]; |
| 43 | + _DataType* in_array = reinterpret_cast<_DataType*>(array1_in); |
| 44 | + _DataType* result = reinterpret_cast<_DataType*>(result1); |
47 | 45 |
|
48 | | - l_result[0] = sqrt(array_1[0]); |
| 46 | + size_t iters = size / (data_size * data_size); |
49 | 47 |
|
50 | | - for (size_t j = 1; j < n; j++) |
| 48 | + for (size_t k = 0; k < iters; ++k) |
51 | 49 | { |
52 | | - l_result[j * n] = array_1[j * n] / l_result[0]; |
53 | | - } |
| 50 | + _DataType matrix[data_size * data_size]; |
| 51 | + _DataType result_[data_size * data_size]; |
54 | 52 |
|
55 | | - for (size_t i = 1; i < n; i++) |
56 | | - { |
57 | | - _DataType sum_val = 0; |
58 | | - for (size_t p = 0; p < i - 1; p++) |
| 53 | + for (size_t t = 0; t < data_size * data_size; ++t) |
59 | 54 | { |
60 | | - sum_val += l_result[i * n + p - 1] * l_result[i * n + p - 1]; |
| 55 | + matrix[t] = in_array[k * (data_size * data_size) + t]; |
| 56 | + |
61 | 57 | } |
62 | | - l_result[i * n + i - 1] = sqrt(array_1[i * n + i - 1] - sum_val); |
63 | | - } |
64 | 58 |
|
65 | | - for (size_t i = 1; i < n - 1; i++) |
66 | | - { |
67 | | - for (size_t j = i; j < n; j++) |
| 59 | + for (size_t it = 0; it < data_size * data_size; ++it) |
68 | 60 | { |
69 | | - _DataType sum_val = 0; |
70 | | - for (size_t p = 0; p < i - 1; p++) |
| 61 | + result_[it] = matrix[it]; |
| 62 | + } |
| 63 | + |
| 64 | + const std::int64_t n = data_size; |
| 65 | + |
| 66 | + const std::int64_t lda = std::max<size_t>(1UL, n); |
| 67 | + |
| 68 | + const std::int64_t scratchpad_size = mkl_lapack::potrf_scratchpad_size<_DataType>( |
| 69 | + DPNP_QUEUE, oneapi::mkl::uplo::upper, n, lda); |
| 70 | + |
| 71 | + _DataType* scratchpad = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(scratchpad_size * sizeof(_DataType))); |
| 72 | + |
| 73 | + event = mkl_lapack::potrf(DPNP_QUEUE, |
| 74 | + oneapi::mkl::uplo::upper, |
| 75 | + n, |
| 76 | + result_, |
| 77 | + lda, |
| 78 | + scratchpad, |
| 79 | + scratchpad_size); |
| 80 | + |
| 81 | + event.wait(); |
| 82 | + |
| 83 | + for (size_t i = 0; i < data_size; i++) |
| 84 | + { |
| 85 | + bool arg = false; |
| 86 | + for (size_t j = 0; j < data_size; j++) |
71 | 87 | { |
72 | | - sum_val += l_result[i * n + p - 1] * l_result[j * n + p - 1]; |
| 88 | + if (i == j - 1) |
| 89 | + { |
| 90 | + arg = true; |
| 91 | + } |
| 92 | + if (arg) |
| 93 | + { |
| 94 | + result_[i * data_size + j] = 0; |
| 95 | + } |
73 | 96 | } |
74 | | - l_result[j * n + i - 1] = (1 / l_result[i * n + i - 1]) * (array_1[j * n + i - 1] - sum_val); |
75 | 97 | } |
| 98 | + |
| 99 | + dpnp_memory_free_c(scratchpad); |
| 100 | + |
| 101 | + for (size_t t = 0; t < data_size * data_size; ++t) |
| 102 | + { |
| 103 | + result[k * (data_size * data_size) + t] = result_[t]; |
| 104 | + |
| 105 | + } |
| 106 | + |
76 | 107 | } |
77 | | - return; |
| 108 | + |
78 | 109 | } |
79 | 110 |
|
80 | 111 | template <typename _DataType> |
@@ -431,8 +462,6 @@ void dpnp_svd_c(void* array1_in, void* result1, void* result2, void* result3, si |
431 | 462 |
|
432 | 463 | void func_map_init_linalg_func(func_map_t& fmap) |
433 | 464 | { |
434 | | - fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_cholesky_c<int>}; |
435 | | - fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_cholesky_c<long>}; |
436 | 465 | fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cholesky_c<float>}; |
437 | 466 | fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cholesky_c<double>}; |
438 | 467 |
|
|
0 commit comments