2323// THE POSSIBILITY OF SUCH DAMAGE.
2424// *****************************************************************************
2525
26+ #include < pybind11/numpy.h>
27+ #include < pybind11/pybind11.h>
28+
29+ #include " ext/common.hpp"
30+
2631#include " ext/validation_utils.hpp"
2732#include " utils/memory_overlap.hpp"
2833
34+ namespace td_ns = dpctl::tensor::type_dispatch;
35+ namespace common = ext::common;
36+
2937namespace ext ::validation
3038{
3139inline sycl::queue get_queue (const std::vector<array_ptr> &inputs,
@@ -137,6 +145,15 @@ inline void check_num_dims(const array_ptr &arr,
137145 }
138146}
139147
148+ inline void check_num_dims (const std::vector<array_ptr> &arrays,
149+ const size_t ndim,
150+ const array_names &names)
151+ {
152+ for (const auto &arr : arrays) {
153+ check_num_dims (arr, ndim, names);
154+ }
155+ }
156+
140157inline void check_max_dims (const array_ptr &arr,
141158 const size_t max_ndim,
142159 const array_names &names)
@@ -163,6 +180,103 @@ inline void check_size_at_least(const array_ptr &arr,
163180 }
164181}
165182
183+ inline void check_has_dtype (const array_ptr &arr,
184+ const typenum_t dtype,
185+ const array_names &names)
186+ {
187+ if (arr == nullptr ) {
188+ return ;
189+ }
190+
191+ auto array_types = td_ns::usm_ndarray_types ();
192+ int array_type_id = array_types.typenum_to_lookup_id (arr->get_typenum ());
193+ int expected_type_id = static_cast <int >(dtype);
194+
195+ if (array_type_id != expected_type_id) {
196+ py::dtype actual_dtype = common::dtype_from_typenum (array_type_id);
197+ py::dtype dtype_py = common::dtype_from_typenum (expected_type_id);
198+
199+ std::string msg = " Array " + name_of (arr, names) + " must have dtype " +
200+ std::string (py::str (dtype_py)) + " , but got " +
201+ std::string (py::str (actual_dtype));
202+
203+ throw py::value_error (msg);
204+ }
205+ }
206+
207+ inline void check_same_dtype (const array_ptr &arr1,
208+ const array_ptr &arr2,
209+ const array_names &names)
210+ {
211+ if (arr1 == nullptr || arr2 == nullptr ) {
212+ return ;
213+ }
214+
215+ auto array_types = td_ns::usm_ndarray_types ();
216+ int first_type_id = array_types.typenum_to_lookup_id (arr1->get_typenum ());
217+ int second_type_id = array_types.typenum_to_lookup_id (arr2->get_typenum ());
218+
219+ if (first_type_id != second_type_id) {
220+ py::dtype first_dtype = common::dtype_from_typenum (first_type_id);
221+ py::dtype second_dtype = common::dtype_from_typenum (second_type_id);
222+
223+ std::string msg = " Arrays " + name_of (arr1, names) + " and " +
224+ name_of (arr2, names) +
225+ " must have the same dtype, but got " +
226+ std::string (py::str (first_dtype)) + " and " +
227+ std::string (py::str (second_dtype));
228+
229+ throw py::value_error (msg);
230+ }
231+ }
232+
233+ inline void check_same_dtype (const std::vector<array_ptr> &arrays,
234+ const array_names &names)
235+ {
236+ if (arrays.empty ()) {
237+ return ;
238+ }
239+
240+ const auto *first = arrays[0 ];
241+ for (size_t i = 1 ; i < arrays.size (); ++i) {
242+ check_same_dtype (first, arrays[i], names);
243+ }
244+ }
245+
246+ inline void check_same_size (const array_ptr &arr1,
247+ const array_ptr &arr2,
248+ const array_names &names)
249+ {
250+ if (arr1 == nullptr || arr2 == nullptr ) {
251+ return ;
252+ }
253+
254+ auto size1 = arr1->get_size ();
255+ auto size2 = arr2->get_size ();
256+
257+ if (size1 != size2) {
258+ std::string msg =
259+ " Arrays " + name_of (arr1, names) + " and " + name_of (arr2, names) +
260+ " must have the same size, but got " + std::to_string (size1) +
261+ " and " + std::to_string (size2);
262+
263+ throw py::value_error (msg);
264+ }
265+ }
266+
267+ inline void check_same_size (const std::vector<array_ptr> &arrays,
268+ const array_names &names)
269+ {
270+ if (arrays.empty ()) {
271+ return ;
272+ }
273+
274+ auto first = arrays[0 ];
275+ for (size_t i = 1 ; i < arrays.size (); ++i) {
276+ check_same_size (first, arrays[i], names);
277+ }
278+ }
279+
166280inline void common_checks (const std::vector<array_ptr> &inputs,
167281 const std::vector<array_ptr> &outputs,
168282 const array_names &names)
0 commit comments