|
23 | 23 | #pragma GCC diagnostic pop |
24 | 24 |
|
25 | 25 | #include "GPUCommonDef.h" |
| 26 | +#include "GPUCommonHelpers.h" |
26 | 27 |
|
27 | | -#ifdef __CUDACC__ |
| 28 | +#ifndef __HIPCC__ // CUDA |
28 | 29 | #define GPUCA_THRUST_NAMESPACE thrust::cuda |
29 | | -#else |
| 30 | +#define GPUCA_CUB_NAMESPACE cub |
| 31 | +#include <cub/cub.cuh> |
| 32 | +#else // HIP |
30 | 33 | #define GPUCA_THRUST_NAMESPACE thrust::hip |
| 34 | +#define GPUCA_CUB_NAMESPACE hipcub |
| 35 | +#include <hipcub/hipcub.hpp> |
31 | 36 | #endif |
32 | 37 |
|
33 | 38 | namespace o2::gpu |
@@ -89,11 +94,20 @@ template <class T, class S> |
89 | 94 | GPUhi() void GPUCommonAlgorithm::sortOnDevice(auto* rec, int32_t stream, T* begin, size_t N, const S& comp) |
90 | 95 | { |
91 | 96 | thrust::device_ptr<T> p(begin); |
| 97 | +#if 0 // Use Thrust |
92 | 98 | auto alloc = rec->getThrustVolatileDeviceAllocator(); |
93 | 99 | thrust::sort(GPUCA_THRUST_NAMESPACE::par(alloc).on(rec->mInternals->Streams[stream]), p, p + N, comp); |
| 100 | +#else // Use CUB |
| 101 | + size_t tempSize = 0; |
| 102 | + void* tempMem = nullptr; |
| 103 | + GPUChkErrS(GPUCA_CUB_NAMESPACE::DeviceMergeSort::SortKeys(tempMem, tempSize, begin, N, comp, rec->mInternals->Streams[stream])); |
| 104 | + tempMem = rec->AllocateVolatileDeviceMemory(tempSize); |
| 105 | + GPUChkErrS(GPUCA_CUB_NAMESPACE::DeviceMergeSort::SortKeys(tempMem, tempSize, begin, N, comp, rec->mInternals->Streams[stream])); |
| 106 | +#endif |
94 | 107 | } |
95 | 108 | } // namespace o2::gpu |
96 | 109 |
|
97 | 110 | #undef GPUCA_THRUST_NAMESPACE |
| 111 | +#undef GPUCA_CUB_NAMESPACE |
98 | 112 |
|
99 | 113 | #endif |
0 commit comments