@@ -34,7 +34,7 @@ struct unops_float_tests {
3434 ASSERT (equals (b[I], items[I] == 0.0 ? T (1.0 ) : T (0.0 )) && ...);
3535
3636 // Ideally, we would test all unary operators, but that would be a lot of work and not that useful since
37- // all operators are generators by the same macro. Instead, we only check a few of them
37+ // all operators are generated by the same macro. Instead, we only check a few of them
3838 if constexpr (is_one_of<T, __half, __nv_bfloat16>) {
3939 // operations on 16-bit numbers are only supported in CC >= 8
4040#if KERNEL_FLOAT_CUDA_ARCH >= 800
@@ -52,6 +52,12 @@ struct unops_float_tests {
5252
5353 b = exp (a);
5454 ASSERT (equals (b[I], hexp (T (items[I]))) && ...);
55+
56+ b = rcp (a);
57+ ASSERT (equals (b[I], hrcp (T (items[I]))) && ...);
58+
59+ b = rsqrt (a);
60+ ASSERT (equals (b[I], hrsqrt (T (items[I]))) && ...);
5561#endif
5662 } else {
5763 b = sqrt (a);
@@ -68,9 +74,15 @@ struct unops_float_tests {
6874
6975 b = exp (a);
7076 ASSERT (equals (b[I], exp (T (items[I]))) && ...);
77+
78+ b = rcp (a);
79+ ASSERT (equals (b[I], rcp (T (items[I]))) && ...);
80+
81+ b = rsqrt (a);
82+ ASSERT (equals (b[I], rsqrt (T (items[I]))) && ...);
7183 }
7284 }
7385};
7486
7587REGISTER_TEST_CASE (" unary float operators" , unops_float_tests, float , double )
76- REGISTER_TEST_CASE_GPU(" unary float operators" , unops_float_tests, __half, __nv_bfloat16)
88+ REGISTER_TEST_CASE_GPU(" unary float operators" , unops_float_tests, __half, __nv_bfloat16)
0 commit comments