Skip to content

Commit 1ef0374

Browse files
authored
Merge pull request NVIDIA#1376 from allisonvacanti/bug/scan_by_key_modernize/gh.1374
Modernize scan_by_key functors / type deductions.
2 parents 730c3bb + f7f2129 commit 1ef0374

4 files changed

Lines changed: 89 additions & 21 deletions

File tree

testing/scan_by_key.cu

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <unittest/unittest.h>
22
#include <thrust/scan.h>
33
#include <thrust/functional.h>
4+
#include <thrust/iterator/discard_iterator.h>
45
#include <thrust/iterator/transform_iterator.h>
56
#include <thrust/iterator/retag.h>
67
#include <thrust/random.h>
@@ -540,6 +541,74 @@ void TestScanByKeyMixedTypes(void)
540541
DECLARE_UNITTEST(TestScanByKeyMixedTypes);
541542

542543

544+
template <typename T>
545+
void TestScanByKeyDiscardOutput(std::size_t n)
546+
{
547+
thrust::host_vector<T> h_keys(n);
548+
thrust::default_random_engine rng;
549+
550+
for (size_t i = 0, k = 0; i < n; i++)
551+
{
552+
h_keys[i] = static_cast<T>(k);
553+
if (rng() % 10 == 0)
554+
{
555+
k++;
556+
}
557+
}
558+
thrust::device_vector<T> d_keys = h_keys;
559+
560+
thrust::host_vector<T> h_vals(n);
561+
for(size_t i = 0; i < n; i++)
562+
{
563+
h_vals[i] = static_cast<T>(i % 10);
564+
}
565+
thrust::device_vector<T> d_vals = h_vals;
566+
567+
auto out = thrust::make_discard_iterator();
568+
569+
// These are no-ops, but they should compile.
570+
thrust::exclusive_scan_by_key(d_keys.cbegin(),
571+
d_keys.cend(),
572+
d_vals.cbegin(),
573+
out);
574+
thrust::exclusive_scan_by_key(d_keys.cbegin(),
575+
d_keys.cend(),
576+
d_vals.cbegin(),
577+
out,
578+
T{});
579+
thrust::exclusive_scan_by_key(d_keys.cbegin(),
580+
d_keys.cend(),
581+
d_vals.cbegin(),
582+
out,
583+
T{},
584+
thrust::equal_to<T>{});
585+
thrust::exclusive_scan_by_key(d_keys.cbegin(),
586+
d_keys.cend(),
587+
d_vals.cbegin(),
588+
out,
589+
T{},
590+
thrust::equal_to<T>{},
591+
thrust::multiplies<T>{});
592+
593+
thrust::inclusive_scan_by_key(d_keys.cbegin(),
594+
d_keys.cend(),
595+
d_vals.cbegin(),
596+
out);
597+
thrust::inclusive_scan_by_key(d_keys.cbegin(),
598+
d_keys.cend(),
599+
d_vals.cbegin(),
600+
out,
601+
thrust::equal_to<T>{});
602+
thrust::inclusive_scan_by_key(d_keys.cbegin(),
603+
d_keys.cend(),
604+
d_vals.cbegin(),
605+
out,
606+
thrust::equal_to<T>{},
607+
thrust::multiplies<T>{});
608+
}
609+
DECLARE_VARIABLE_UNITTEST(TestScanByKeyDiscardOutput);
610+
611+
543612
void TestScanByKeyLargeInput()
544613
{
545614
const unsigned int N = 1 << 20;

thrust/system/cuda/detail/scan_by_key.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -844,14 +844,14 @@ inclusive_scan_by_key(execution_policy<Derived> &policy,
844844
ValOutputIt value_result,
845845
BinaryPred binary_pred)
846846
{
847-
typedef typename thrust::iterator_traits<ValOutputIt>::value_type value_type;
847+
typedef typename thrust::iterator_traits<ValInputIt>::value_type value_type;
848848
return cuda_cub::inclusive_scan_by_key(policy,
849849
key_first,
850850
key_last,
851851
value_first,
852852
value_result,
853853
binary_pred,
854-
plus<value_type>());
854+
thrust::plus<>());
855855
}
856856

857857
template <class Derived,
@@ -871,7 +871,7 @@ inclusive_scan_by_key(execution_policy<Derived> &policy,
871871
key_last,
872872
value_first,
873873
value_result,
874-
equal_to<key_type>());
874+
thrust::equal_to<>());
875875
}
876876

877877

@@ -948,7 +948,7 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
948948
value_result,
949949
init,
950950
binary_pred,
951-
plus<Init>());
951+
plus<>());
952952
}
953953

954954
template <class Derived,
@@ -971,7 +971,7 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
971971
value_first,
972972
value_result,
973973
init,
974-
equal_to<key_type>());
974+
equal_to<>());
975975
}
976976

977977

@@ -986,13 +986,13 @@ exclusive_scan_by_key(execution_policy<Derived> &policy,
986986
ValInputIt value_first,
987987
ValOutputIt value_result)
988988
{
989-
typedef typename iterator_traits<ValOutputIt>::value_type value_type;
989+
typedef typename iterator_traits<ValInputIt>::value_type value_type;
990990
return cuda_cub::exclusive_scan_by_key(policy,
991991
key_first,
992992
key_last,
993993
value_first,
994994
value_result,
995-
value_type(0));
995+
value_type{});
996996
}
997997

998998

thrust/system/detail/generic/scan_by_key.inl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
#include <thrust/detail/config.h>
19+
#include <thrust/detail/cstdint.h>
1920
#include <thrust/system/detail/generic/scan_by_key.h>
2021
#include <thrust/functional.h>
2122
#include <thrust/transform.h>
@@ -71,8 +72,7 @@ __host__ __device__
7172
InputIterator2 first2,
7273
OutputIterator result)
7374
{
74-
typedef typename thrust::iterator_traits<InputIterator1>::value_type InputType1;
75-
return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, thrust::equal_to<InputType1>());
75+
return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, thrust::equal_to<>());
7676
}
7777

7878

@@ -108,8 +108,8 @@ __host__ __device__
108108
BinaryPredicate binary_pred,
109109
AssociativeOperator binary_op)
110110
{
111-
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
112-
typedef unsigned int HeadFlagType;
111+
using OutputType = typename thrust::iterator_traits<InputIterator2>::value_type;
112+
using HeadFlagType = thrust::detail::uint32_t;
113113

114114
const size_t n = last1 - first1;
115115

@@ -146,8 +146,8 @@ __host__ __device__
146146
InputIterator2 first2,
147147
OutputIterator result)
148148
{
149-
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
150-
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, OutputType(0));
149+
typedef typename thrust::iterator_traits<InputIterator2>::value_type InitType;
150+
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, InitType{});
151151
}
152152

153153

@@ -164,8 +164,7 @@ __host__ __device__
164164
OutputIterator result,
165165
T init)
166166
{
167-
typedef typename thrust::iterator_traits<InputIterator1>::value_type InputType1;
168-
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, thrust::equal_to<InputType1>());
167+
return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, thrust::equal_to<>());
169168
}
170169

171170

@@ -205,8 +204,8 @@ __host__ __device__
205204
BinaryPredicate binary_pred,
206205
AssociativeOperator binary_op)
207206
{
208-
typedef typename thrust::iterator_traits<OutputIterator>::value_type OutputType;
209-
typedef unsigned int HeadFlagType;
207+
using OutputType = T;
208+
using HeadFlagType = thrust::detail::uint32_t;
210209

211210
const size_t n = last1 - first1;
212211

thrust/system/detail/sequential/scan_by_key.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ __host__ __device__
5252
BinaryPredicate binary_pred,
5353
BinaryFunction binary_op)
5454
{
55-
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
56-
typedef typename thrust::iterator_traits<OutputIterator>::value_type ValueType;
55+
using KeyType = typename thrust::iterator_traits<InputIterator1>::value_type;
56+
using ValueType = typename thrust::iterator_traits<InputIterator2>::value_type;
5757

5858
// wrap binary_op
5959
thrust::detail::wrapped_function<
@@ -105,8 +105,8 @@ __host__ __device__
105105
BinaryPredicate binary_pred,
106106
BinaryFunction binary_op)
107107
{
108-
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
109-
typedef typename thrust::iterator_traits<OutputIterator>::value_type ValueType;
108+
using KeyType = typename thrust::iterator_traits<InputIterator1>::value_type;
109+
using ValueType = T;
110110

111111
if(first1 != last1)
112112
{

0 commit comments

Comments
 (0)