Skip to content

Commit c34e4f9

Browse files
native kernel for funcs: ones, ones_like, zeros, zeros_like, full, full_like (#536)
* native kernel for funcs: ones, ones_like, zeros, zeros_like, full, full_like
1 parent 96bfc2f commit c34e4f9

9 files changed

Lines changed: 170 additions & 98 deletions

File tree

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,17 @@ INP_DLLEXPORT void dpnp_cov_c(void* array1_in, void* result1, size_t nrows, size
284284
template <typename _DataType>
285285
INP_DLLEXPORT void dpnp_det_c(void* array1_in, void* result1, size_t* shape, size_t ndim);
286286

287+
/**
288+
* @ingroup BACKEND_API
289+
* @brief Implementation of a function that creates an array filled with a single value
290+
*
291+
* @param [out] result1 Output array.
292+
* @param [in] value Pointer to the single value that every element of the array is set to.
293+
* @param [in] size Number of elements in array.
294+
*/
295+
template <typename _DataType>
296+
INP_DLLEXPORT void dpnp_initval_c(void* result1, void* value, size_t size);
297+
287298
/**
288299
* @ingroup BACKEND_API
289300
* @brief math library implementation of inv function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ enum class DPNPFuncName : size_t
100100
DPNP_FN_FLOOR_DIVIDE, /**< Used in numpy.floor_divide() implementation */
101101
DPNP_FN_FMOD, /**< Used in numpy.fmod() implementation */
102102
DPNP_FN_HYPOT, /**< Used in numpy.hypot() implementation */
103+
DPNP_FN_INITVAL, /**< Used in numpy ones, ones_like, zeros, zeros_like, full, full_like impl */
103104
DPNP_FN_INV, /**< Used in numpy.linalg.inv() implementation */
104105
DPNP_FN_INVERT, /**< Used in numpy.invert() implementation */
105106
DPNP_FN_KRON, /**< Used in numpy.kron() implementation */
@@ -181,12 +182,13 @@ enum class DPNPFuncName : size_t
181182
*/
182183
enum class DPNPFuncType : size_t
183184
{
184-
DPNP_FT_NONE, /**< Very first element of the enumeration */
185-
DPNP_FT_INT, /**< analog of numpy.int32 or int */
186-
DPNP_FT_LONG, /**< analog of numpy.int64 or long */
187-
DPNP_FT_FLOAT, /**< analog of numpy.float32 or float */
188-
DPNP_FT_DOUBLE, /**< analog of numpy.float32 or double */
189-
DPNP_FT_CMPLX128 /**< analog of numpy.complex128 or std::complex<double> */
185+
DPNP_FT_NONE, /**< Very first element of the enumeration */
186+
DPNP_FT_INT, /**< analog of numpy.int32 or int */
187+
DPNP_FT_LONG, /**< analog of numpy.int64 or long */
188+
DPNP_FT_FLOAT, /**< analog of numpy.float32 or float */
189+
DPNP_FT_DOUBLE, /**< analog of numpy.float64 or double */
190+
DPNP_FT_CMPLX128, /**< analog of numpy.complex128 or std::complex<double> */
191+
DPNP_FT_BOOL /**< analog of numpy.bool or numpy.bool_ or bool */
190192
};
191193

192194
/**

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 112 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -35,82 +35,6 @@
3535
namespace mkl_blas = oneapi::mkl::blas;
3636
namespace mkl_lapack = oneapi::mkl::lapack;
3737

38-
template <typename _KernelNameSpecialization>
39-
class dpnp_matmul_c_kernel;
40-
41-
template <typename _DataType>
42-
void dpnp_matmul_c(void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k)
43-
{
44-
cl::sycl::event event;
45-
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
46-
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
47-
_DataType* result = reinterpret_cast<_DataType*>(result1);
48-
49-
if (!size_m || !size_n || !size_k)
50-
{
51-
return;
52-
}
53-
54-
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
55-
{
56-
// using std::max for these ldx variables is required by math library
57-
const std::int64_t lda = std::max<size_t>(1UL, size_k); // First dimensions of array_1
58-
const std::int64_t ldb = std::max<size_t>(1UL, size_n); // First dimensions of array_2
59-
const std::int64_t ldc = std::max<size_t>(1UL, size_n); // Fast dimensions of result
60-
61-
event = mkl_blas::gemm(DPNP_QUEUE,
62-
oneapi::mkl::transpose::nontrans,
63-
oneapi::mkl::transpose::nontrans,
64-
size_n,
65-
size_m,
66-
size_k,
67-
_DataType(1),
68-
array_2,
69-
ldb,
70-
array_1,
71-
lda,
72-
_DataType(0),
73-
result,
74-
ldc);
75-
}
76-
else
77-
{
78-
// input1: M x K
79-
// input2: K x N
80-
// result: M x N
81-
const size_t dim_m = size_m; // shape1.front(); // First dimensions of array1
82-
const size_t dim_n = size_n; // shape2.back(); // Last dimensions of array2
83-
const size_t dim_k = size_k; // shape1.back(); // First dimensions of array2
84-
85-
cl::sycl::range<2> gws(dim_m, dim_n); // dimensions are: "i" and "j"
86-
87-
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
88-
size_t i = global_id[0]; //for (size_t i = 0; i < size; ++i)
89-
{
90-
size_t j = global_id[1]; //for (size_t j = 0; j < size; ++j)
91-
{
92-
_DataType acc = _DataType(0);
93-
for (size_t k = 0; k < dim_k; ++k)
94-
{
95-
const size_t index_1 = i * dim_k + k;
96-
const size_t index_2 = k * dim_n + j;
97-
acc += array_1[index_1] * array_2[index_2];
98-
}
99-
const size_t index_result = i * dim_n + j;
100-
result[index_result] = acc;
101-
}
102-
}
103-
};
104-
105-
auto kernel_func = [&](cl::sycl::handler& cgh) {
106-
cgh.parallel_for<class dpnp_matmul_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
107-
};
108-
109-
event = DPNP_QUEUE.submit(kernel_func);
110-
}
111-
event.wait();
112-
}
113-
11438
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
11539
class dpnp_dot_c_kernel;
11640

@@ -293,6 +217,111 @@ void dpnp_eigvals_c(const void* array_in, void* result1, size_t size)
293217
dpnp_memory_free_c(result_val_kern);
294218
}
295219

220+
template <typename _DataType>
221+
class dpnp_initval_c_kernel;
222+
223+
template <typename _DataType>
224+
void dpnp_initval_c(void* result1, void* value, size_t size)
225+
{
226+
if (!size)
227+
{
228+
return;
229+
}
230+
231+
_DataType* result = reinterpret_cast<_DataType*>(result1);
232+
_DataType val = *(reinterpret_cast<_DataType*>(value));
233+
234+
cl::sycl::range<1> gws(size);
235+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
236+
const size_t idx = global_id[0];
237+
result[idx] = val;
238+
};
239+
240+
auto kernel_func = [&](cl::sycl::handler& cgh) {
241+
cgh.parallel_for<class dpnp_initval_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
242+
};
243+
244+
cl::sycl::event event = DPNP_QUEUE.submit(kernel_func);
245+
246+
event.wait();
247+
}
248+
249+
template <typename _KernelNameSpecialization>
250+
class dpnp_matmul_c_kernel;
251+
252+
template <typename _DataType>
253+
void dpnp_matmul_c(void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k)
254+
{
255+
cl::sycl::event event;
256+
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
257+
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
258+
_DataType* result = reinterpret_cast<_DataType*>(result1);
259+
260+
if (!size_m || !size_n || !size_k)
261+
{
262+
return;
263+
}
264+
265+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
266+
{
267+
// using std::max for these ldx variables is required by math library
268+
const std::int64_t lda = std::max<size_t>(1UL, size_k); // First dimensions of array_1
269+
const std::int64_t ldb = std::max<size_t>(1UL, size_n); // First dimensions of array_2
270+
const std::int64_t ldc = std::max<size_t>(1UL, size_n); // Fast dimensions of result
271+
272+
event = mkl_blas::gemm(DPNP_QUEUE,
273+
oneapi::mkl::transpose::nontrans,
274+
oneapi::mkl::transpose::nontrans,
275+
size_n,
276+
size_m,
277+
size_k,
278+
_DataType(1),
279+
array_2,
280+
ldb,
281+
array_1,
282+
lda,
283+
_DataType(0),
284+
result,
285+
ldc);
286+
}
287+
else
288+
{
289+
// input1: M x K
290+
// input2: K x N
291+
// result: M x N
292+
const size_t dim_m = size_m; // shape1.front(); // First dimensions of array1
293+
const size_t dim_n = size_n; // shape2.back(); // Last dimensions of array2
294+
const size_t dim_k = size_k; // shape1.back(); // First dimensions of array2
295+
296+
cl::sycl::range<2> gws(dim_m, dim_n); // dimensions are: "i" and "j"
297+
298+
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
299+
size_t i = global_id[0]; //for (size_t i = 0; i < size; ++i)
300+
{
301+
size_t j = global_id[1]; //for (size_t j = 0; j < size; ++j)
302+
{
303+
_DataType acc = _DataType(0);
304+
for (size_t k = 0; k < dim_k; ++k)
305+
{
306+
const size_t index_1 = i * dim_k + k;
307+
const size_t index_2 = k * dim_n + j;
308+
acc += array_1[index_1] * array_2[index_2];
309+
}
310+
const size_t index_result = i * dim_n + j;
311+
result[index_result] = acc;
312+
}
313+
}
314+
};
315+
316+
auto kernel_func = [&](cl::sycl::handler& cgh) {
317+
cgh.parallel_for<class dpnp_matmul_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
318+
};
319+
320+
event = DPNP_QUEUE.submit(kernel_func);
321+
}
322+
event.wait();
323+
}
324+
296325
void func_map_init_linalg(func_map_t& fmap)
297326
{
298327
fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_dot_c<int, int, int>};
@@ -321,6 +350,13 @@ void func_map_init_linalg(func_map_t& fmap)
321350
fmap[DPNPFuncName::DPNP_FN_EIGVALS][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_eigvals_c<float, float>};
322351
fmap[DPNPFuncName::DPNP_FN_EIGVALS][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_eigvals_c<double, double>};
323352

353+
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_BOOL][eft_BOOL] = {eft_BOOL, (void*)dpnp_initval_c<bool>};
354+
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_initval_c<int>};
355+
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_initval_c<long>};
356+
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_initval_c<float>};
357+
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_initval_c<double>};
358+
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_initval_c<std::complex<double>>};
359+
324360
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_matmul_c<int>};
325361
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_matmul_c<long>};
326362
fmap[DPNPFuncName::DPNP_FN_MATMUL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matmul_c<float>};

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,17 +285,17 @@ void dpnp_arange_c(size_t start, size_t step, void* result1, size_t size)
285285

286286
static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
287287
{
288-
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_arange_c<double>};
289-
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_arange_c<float>};
290288
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_arange_c<int>};
291289
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_arange_c<long>};
290+
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_arange_c<float>};
291+
fmap[DPNPFuncName::DPNP_FN_ARANGE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_arange_c<double>};
292292

293-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_C128][eft_C128] = {eft_C128,
294-
(void*)dpnp_conjugate_c<std::complex<double>>};
295-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_copy_c<double>};
296-
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_copy_c<float>};
297293
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_copy_c<int>};
298294
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_copy_c<long>};
295+
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_copy_c<float>};
296+
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_copy_c<double>};
297+
fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_C128][eft_C128] = {eft_C128,
298+
(void*)dpnp_conjugate_c<std::complex<double>>};
299299

300300
fmap[DPNPFuncName::DPNP_FN_ERF][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_erf_c<int>};
301301
fmap[DPNPFuncName::DPNP_FN_ERF][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_erf_c<long>};

dpnp/backend/src/dpnp_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ const DPNPFuncType eft_LNG = DPNPFuncType::DPNP_FT_LONG;
6161
const DPNPFuncType eft_FLT = DPNPFuncType::DPNP_FT_FLOAT;
6262
const DPNPFuncType eft_DBL = DPNPFuncType::DPNP_FT_DOUBLE;
6363
const DPNPFuncType eft_C128 = DPNPFuncType::DPNP_FT_CMPLX128;
64+
const DPNPFuncType eft_BOOL = DPNPFuncType::DPNP_FT_BOOL;
6465

6566
/**
6667
* FPTR interface initialization functions

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
7373
DPNP_FN_FLOOR_DIVIDE
7474
DPNP_FN_FMOD
7575
DPNP_FN_HYPOT
76+
DPNP_FN_INITVAL
7677
DPNP_FN_INV
7778
DPNP_FN_INVERT
7879
DPNP_FN_KRON
@@ -151,6 +152,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncType": # need this na
151152
DPNP_FT_FLOAT
152153
DPNP_FT_DOUBLE
153154
DPNP_FT_CMPLX128
155+
DPNP_FT_BOOL
154156

155157
cdef extern from "dpnp_iface_fptr.hpp":
156158
struct DPNPFuncData:

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ include "dpnp_algo_trigonometric.pyx"
6969

7070

7171
ctypedef void(*fptr_dpnp_arange_t)(size_t, size_t, void * , size_t)
72+
ctypedef void(*fptr_dpnp_initval_t)(void * , void * , size_t)
73+
7274

7375
cpdef dparray dpnp_arange(start, stop, step, dtype):
7476

@@ -132,10 +134,19 @@ cpdef dparray dpnp_astype(dparray array1, dtype_target):
132134

133135

134136
cpdef dparray dpnp_init_val(shape, dtype, value):
137+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
138+
139+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_INITVAL, param1_type, param1_type)
140+
141+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
135142
cdef dparray result = dparray(shape, dtype=dtype)
136143

137-
for i in range(result.size):
138-
result[i] = value
144+
# TODO: find a better way to pass a single value with type conversion
145+
cdef dparray val_arr = dparray((1, ), dtype=dtype)
146+
val_arr[0] = value
147+
148+
cdef fptr_dpnp_initval_t func = <fptr_dpnp_initval_t > kernel_data.ptr
149+
func(result.get_data(), val_arr.get_data(), result.size)
139150

140151
return result
141152

@@ -244,16 +255,18 @@ Internal functions
244255
"""
245256
cpdef DPNPFuncType dpnp_dtype_to_DPNPFuncType(dtype):
246257

247-
if dtype == numpy.float64:
258+
if dtype in [numpy.float64, 'float64']:
248259
return DPNP_FT_DOUBLE
249-
elif dtype == numpy.float32:
260+
elif dtype in [numpy.float32, 'float32']:
250261
return DPNP_FT_FLOAT
251-
elif dtype == numpy.int64:
262+
elif dtype in [numpy.int64, 'int64', 'int', int]:
252263
return DPNP_FT_LONG
253-
elif dtype == numpy.int32:
264+
elif dtype in [numpy.int32, 'int32']:
254265
return DPNP_FT_INT
255-
elif dtype == numpy.complex128:
266+
elif dtype in [numpy.complex128, 'complex128']:
256267
return DPNP_FT_CMPLX128
268+
elif dtype in [numpy.bool, numpy.bool_, 'bool']:
269+
return DPNP_FT_BOOL
257270
else:
258271
checker_throw_type_error("dpnp_dtype_to_DPNPFuncType", dtype)
259272

@@ -272,6 +285,8 @@ cpdef dpnp_DPNPFuncType_to_dtype(size_t type):
272285
return numpy.int32
273286
elif type == <size_t > DPNP_FT_CMPLX128:
274287
return numpy.complex128
288+
elif type == <size_t > DPNP_FT_BOOL:
289+
return numpy.bool
275290
else:
276291
checker_throw_type_error("dpnp_DPNPFuncType_to_dtype", type)
277292

dpnp/dpnp_iface_arraycreation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def full(shape, fill_value, dtype=None, order='C'):
616616
if order not in ('C', 'c', None):
617617
checker_throw_value_error("full", "order", order, 'C')
618618

619-
_dtype = dtype if dtype is not None else type(fill_value)
619+
_dtype = dtype if dtype is not None else dpnp.dtype(type(fill_value))
620620

621621
return dpnp_init_val(shape, _dtype, fill_value)
622622

@@ -981,7 +981,9 @@ def ones(shape, dtype=None, order='C'):
981981
if order not in ('C', 'c', None):
982982
checker_throw_value_error("ones", "order", order, 'C')
983983

984-
return dpnp_init_val(shape, dtype, 1)
984+
_dtype = dtype if dtype is not None else dpnp.float64
985+
986+
return dpnp_init_val(shape, _dtype, 1)
985987

986988
return numpy.ones(shape, dtype=dtype, order=order)
987989

@@ -1148,7 +1150,9 @@ def zeros(shape, dtype=None, order='C'):
11481150
if order not in ('C', 'c', None):
11491151
checker_throw_value_error("zeros", "order", order, 'C')
11501152

1151-
return dpnp_init_val(shape, dtype, 0)
1153+
_dtype = dtype if dtype is not None else dpnp.float64
1154+
1155+
return dpnp_init_val(shape, _dtype, 0)
11521156

11531157
return numpy.zeros(shape, dtype=dtype, order=order)
11541158

0 commit comments

Comments
 (0)