@@ -8,13 +8,13 @@ struct unops_tests {
88 kf::vec<T, N> b;
99
1010 b = -a;
11- ASSERT ( equals ( b[I], T (-items[I])) && ... );
11+ ASSERT_EQ_ALL ( b[I], T (-items[I]));
1212
1313 b = ~a;
14- ASSERT ( equals ( b[I], T (~items[I])) && ... );
14+ ASSERT_EQ_ALL ( b[I], T (~items[I]));
1515
1616 b = !a;
17- ASSERT ( equals ( b[I], T (!items[I])) && ... );
17+ ASSERT_EQ_ALL ( b[I], T (!items[I]));
1818 }
1919};
2020
@@ -28,58 +28,59 @@ struct unops_float_tests {
2828 kf::vec<T, N> b;
2929
3030 b = -a;
31- ASSERT ( equals ( b[I], T (-items[I])) && ... );
31+ ASSERT_EQ_ALL ( b[I], T (-items[I]));
3232
3333 b = !a;
34- ASSERT ( equals ( b[I], items[I] == 0.0 ? T (1.0 ) : T (0.0 )) && ... );
34+ ASSERT_EQ_ALL ( b[I], items[I] == 0.0 ? T (1.0 ) : T (0.0 ));
3535
3636 // Ideally, we would test all unary operators, but that would be a lot of work and not that useful since
3737 // all operators are generated by the same macro. Instead, we only check a few of them
3838 if constexpr (is_one_of<T, __half, __nv_bfloat16>) {
3939 // operations on 16-bit numbers are only supported in CC >= 8
4040#if KERNEL_FLOAT_CUDA_ARCH >= 800
4141 b = sqrt (a);
42- ASSERT ( equals ( b[I], hsqrt (T (items[I]))) && ... );
42+ ASSERT_EQ_ALL ( b[I], hsqrt (T (items[I])));
4343
4444 b = sin (a);
45- ASSERT ( equals ( b[I], hsin (T (items[I]))) && ... );
45+ ASSERT_EQ_ALL ( b[I], hsin (T (items[I])));
4646
4747 b = cos (a);
48- ASSERT ( equals ( b[I], hcos (T (items[I]))) && ... );
48+ ASSERT_EQ_ALL ( b[I], hcos (T (items[I])));
4949
5050 b = log (a);
51- ASSERT ( equals ( b[I], hlog (T (items[I]))) && ... );
51+ ASSERT_EQ_ALL ( b[I], hlog (T (items[I])));
5252
5353 b = exp (a);
54- ASSERT ( equals ( b[I], hexp (T (items[I]))) && ... );
54+ ASSERT_EQ_ALL ( b[I], hexp (T (items[I])));
5555
5656 b = rcp (a);
57- ASSERT ( equals ( b[I], hrcp (T (items[I]))) && ... );
57+ ASSERT_EQ_ALL ( b[I], hrcp (T (items[I])));
5858
5959 b = rsqrt (a);
60- ASSERT ( equals ( b[I], hrsqrt (T (items[I]))) && ... );
60+ ASSERT_EQ_ALL ( b[I], hrsqrt (T (items[I])));
6161#endif
6262 } else {
6363 b = sqrt (a);
64- ASSERT ( equals ( b[I], sqrt (T (items[I]))) && ... );
64+ ASSERT_EQ_ALL ( b[I], sqrt (T (items[I])));
6565
6666 b = sin (a);
67- ASSERT ( equals ( b[I], sin (T (items[I]))) && ... );
67+ ASSERT_EQ_ALL ( b[I], sin (T (items[I])));
6868
6969 b = cos (a);
70- ASSERT ( equals ( b[I], cos (T (items[I]))) && ... );
70+ ASSERT_EQ_ALL ( b[I], cos (T (items[I])));
7171
7272 b = log (a);
73- ASSERT ( equals ( b[I], log (T (items[I]))) && ... );
73+ ASSERT_EQ_ALL ( b[I], log (T (items[I])));
7474
7575 b = exp (a);
76- ASSERT ( equals ( b[I], exp (T (items[I]))) && ... );
76+ ASSERT_EQ_ALL ( b[I], exp (T (items[I])));
7777
7878 b = rcp (a);
79- ASSERT ( equals ( b[I], rcp ( T ( items[I]))) && ... );
79+ ASSERT_EQ_ALL ( b[I], T ( 1.0 / items[I]));
8080
81+ // seems that rsqrt does not match bitwise on GPU
8182 b = rsqrt (a);
82- ASSERT ( equals ( b[I], rsqrt (T (items[I]))) && ... );
83+ ASSERT_APPROX_ALL ( b[I], rsqrt (T (items[I])));
8384 }
8485 }
8586};
0 commit comments