|
| 1 | +#include <thrust/device_vector.h> |
| 2 | +#include <thrust/scan.h> |
| 3 | +#include <thrust/iterator/transform_iterator.h> |
| 4 | +#include <thrust/iterator/counting_iterator.h> |
| 5 | + |
| 6 | +#include <assert.h> |
| 7 | + |
| 8 | +// We have a matrix stored in a `thrust::device_vector`. We want to perform a |
| 9 | +// scan on each row of a matrix. |
| 10 | + |
| 11 | +__host__ |
| 12 | +void scan_matrix_by_rows0(thrust::device_vector<int>& u, int n, int m) { |
| 13 | + // Here, we launch a separate scan for each row in the matrix. This works, |
| 14 | + // but each kernel only does a small amount of work. It would be better if we |
| 15 | + // could launch one big kernel for the entire matrix. |
| 16 | + for (int i = 0; i < n; ++i) |
| 17 | + thrust::inclusive_scan(u.begin() + m * i, u.begin() + m * (i + 1), |
| 18 | + u.begin() + m * i); |
| 19 | +} |
| 20 | + |
| 21 | +// We can batch the operation using `thrust::inclusive_scan_by_key`, which |
| 22 | +// scans each group of consecutive equal keys. All we need to do is generate |
| 23 | +// the right key sequence. We want the keys for elements on the same row to |
| 24 | +// be identical. |
| 25 | + |
| 26 | +// So first, we define an unary function object which takes the index of an |
| 27 | +// element and returns the row that it belongs to. |
| 28 | + |
| 29 | +struct which_row : thrust::unary_function<int, int> { |
| 30 | + int row_length; |
| 31 | + |
| 32 | + __host__ __device__ |
| 33 | + which_row(int row_length_) : row_length(row_length_) {} |
| 34 | + |
| 35 | + __host__ __device__ |
| 36 | + int operator()(int idx) const { |
| 37 | + return idx / row_length; |
| 38 | + } |
| 39 | +}; |
| 40 | + |
| 41 | +__host__ |
| 42 | +void scan_matrix_by_rows1(thrust::device_vector<int>& u, int n, int m) { |
| 43 | + // This `thrust::counting_iterator` represents the index of the element. |
| 44 | + thrust::counting_iterator<int> c_first(0); |
| 45 | + |
| 46 | + // We construct a `thrust::transform_iterator` which applies the `which_row` |
| 47 | + // function object to the index of each element. |
| 48 | + thrust::transform_iterator<which_row, thrust::counting_iterator<int> > |
| 49 | + t_first(c_first, which_row(m)); |
| 50 | + |
| 51 | + // Finally, we use our `thrust::transform_iterator` as the key sequence to |
| 52 | + // `thrust::inclusive_scan_by_key`. |
| 53 | + thrust::inclusive_scan_by_key(t_first, t_first + n * m, u.begin(), u.begin()); |
| 54 | +} |
| 55 | + |
| 56 | +int main() { |
| 57 | + int const n = 4; |
| 58 | + int const m = 5; |
| 59 | + |
| 60 | + thrust::device_vector<int> u0(n * m); |
| 61 | + thrust::sequence(u0.begin(), u0.end()); |
| 62 | + scan_matrix_by_rows0(u0, n, m); |
| 63 | + |
| 64 | + thrust::device_vector<int> u1(n * m); |
| 65 | + thrust::sequence(u1.begin(), u1.end()); |
| 66 | + scan_matrix_by_rows1(u1, n, m); |
| 67 | + |
| 68 | + for (int i = 0; i < n; ++i) |
| 69 | + for (int j = 0; j < m; ++j) |
| 70 | + assert(u0[j + m * i] == u1[j + m * i]); |
| 71 | +} |
| 72 | + |
0 commit comments