Skip to content

Commit cec823a

Browse files
FIX: random.gumbel segfault on GPU (#560)
* FIX: random.gumbel segfault on GPU
1 parent c34e4f9 commit cec823a

2 files changed

Lines changed: 24 additions & 8 deletions

File tree

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ void dpnp_rng_geometric_c(void* result, const float p, const size_t size)
298298
event_out.wait();
299299
}
300300

301+
template <typename _KernelNameSpecialization>
302+
class dpnp_blas_scal_c_kernel;
303+
301304
template <typename _DataType>
302305
void dpnp_rng_gumbel_c(void* result, const double loc, const double scale, const size_t size)
303306
{
@@ -308,15 +311,33 @@ void dpnp_rng_gumbel_c(void* result, const double loc, const double scale, const
308311
}
309312

310313
const _DataType alpha = (_DataType(-1.0));
311-
const _DataType stride = (_DataType(1.0));
314+
std::int64_t incx = 1;
312315
_DataType* result1 = reinterpret_cast<_DataType*>(result);
313316
double negloc = loc * (double(-1.0));
314317

315318
mkl_rng::gumbel<_DataType> distribution(negloc, scale);
316-
// perform generation
317319
event = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
318320
event.wait();
319-
event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, stride);
321+
322+
// OK for CPU and segfault for GPU device
323+
// event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, incx);
324+
if (dpnp_queue_is_cpu_c())
325+
{
326+
event = mkl_blas::scal(DPNP_QUEUE, size, alpha, result1, incx);
327+
}
328+
else
329+
{
330+
// for (size_t i = 0; i < size; i++) result1[i] *= alpha;
331+
cl::sycl::range<1> gws(size);
332+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
333+
size_t i = global_id[0];
334+
result1[i] *= alpha;
335+
};
336+
auto kernel_func = [&](cl::sycl::handler& cgh) {
337+
cgh.parallel_for<class dpnp_blas_scal_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
338+
};
339+
event = DPNP_QUEUE.submit(kernel_func);
340+
}
320341
event.wait();
321342
}
322343

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
// gumbel distribution issue https://github.com/IntelPython/dpnp/issues/569
2-
tests/test_random.py::TestDistributionsGumbel::test_extreme_value
3-
tests/test_random.py::TestDistributionsGumbel::test_invalid_args
4-
tests/test_random.py::TestDistributionsGumbel::test_moments
5-
tests/test_random.py::TestDistributionsGumbel::test_seed
61
//----------------------------------------------------------------------
72
// eig/eigvals/svd issue https://github.com/IntelPython/dpnp/issues/567
83
tests/test_linalg.py::test_eig_arange[2-float64]

0 commit comments

Comments
 (0)