3131#include < pybind11/stl.h>
3232
3333#include " common.hpp"
34+ #include " cos.hpp"
3435#include " div.hpp"
3536#include " ln.hpp"
37+ #include " sin.hpp"
3638#include " types_matrix.hpp"
3739
3840namespace py = pybind11;
@@ -43,7 +45,9 @@ using vm_ext::unary_impl_fn_ptr_t;
4345
4446static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
4547
48+ static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
4649static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
50+ static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
4751
4852PYBIND11_MODULE (_vm_impl, m)
4953{
@@ -80,6 +84,34 @@ PYBIND11_MODULE(_vm_impl, m)
8084 py::arg (" dst" ));
8185 }
8286
87+ // UnaryUfunc: ==== Cos(x) ====
88+ {
89+ vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
90+ vm_ext::CosContigFactory>(
91+ cos_dispatch_vector);
92+
93+ auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
94+ const event_vecT &depends = {}) {
95+ return vm_ext::unary_ufunc (exec_q, src, dst, depends,
96+ cos_dispatch_vector);
97+ };
98+ m.def (" _cos" , cos_pyapi,
99+ " Call `cos` function from OneMKL VM library to compute "
100+ " cosine of vector elements" ,
101+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ),
102+ py::arg (" depends" ) = py::list ());
103+
104+ auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
105+ arrayT dst) {
106+ return vm_ext::need_to_call_unary_ufunc (exec_q, src, dst,
107+ cos_dispatch_vector);
108+ };
109+ m.def (" _mkl_cos_to_call" , cos_need_to_call_pyapi,
110+ " Check input arguments to answer if `cos` function from "
111+ " OneMKL VM library can be used" ,
112+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
113+ }
114+
83115 // UnaryUfunc: ==== Ln(x) ====
84116 {
85117 vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
@@ -107,4 +139,32 @@ PYBIND11_MODULE(_vm_impl, m)
107139 " OneMKL VM library can be used" ,
108140 py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
109141 }
142+
143+ // UnaryUfunc: ==== Sin(x) ====
144+ {
145+ vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t ,
146+ vm_ext::SinContigFactory>(
147+ sin_dispatch_vector);
148+
149+ auto sin_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
150+ const event_vecT &depends = {}) {
151+ return vm_ext::unary_ufunc (exec_q, src, dst, depends,
152+ sin_dispatch_vector);
153+ };
154+ m.def (" _sin" , sin_pyapi,
155+ " Call `sin` function from OneMKL VM library to compute "
156+ " sine of vector elements" ,
157+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ),
158+ py::arg (" depends" ) = py::list ());
159+
160+ auto sin_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
161+ arrayT dst) {
162+ return vm_ext::need_to_call_unary_ufunc (exec_q, src, dst,
163+ sin_dispatch_vector);
164+ };
165+ m.def (" _mkl_sin_to_call" , sin_need_to_call_pyapi,
166+ " Check input arguments to answer if `sin` function from "
167+ " OneMKL VM library can be used" ,
168+ py::arg (" sycl_queue" ), py::arg (" src" ), py::arg (" dst" ));
169+ }
110170}
0 commit comments