Skip to content

Commit 823a8cf

Browse files
committed
Fix cast on constant<T>
1 parent 8f363a2 commit 823a8cf

5 files changed

Lines changed: 92 additions & 27 deletions

File tree

include/kernel_float/constant.h

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,20 @@ struct cast<constant<T>, R, m> {
9898
return cast<T, R, m> {}(input);
9999
}
100100
};
101+
102+
template<typename T>
103+
struct cast<constant<T>, float> {
104+
KERNEL_FLOAT_INLINE float operator()(const T& input) noexcept {
105+
return cast<T, float> {}(input);
106+
}
107+
};
108+
109+
template<typename T, RoundingMode m>
110+
struct cast<constant<T>, float, m> {
111+
KERNEL_FLOAT_INLINE float operator()(const T& input) noexcept {
112+
return cast<T, float, m> {}(input);
113+
}
114+
};
101115
} // namespace ops
102116

103117
#define KERNEL_FLOAT_CONSTANT_DEFINE_OP(OP) \
@@ -140,6 +154,13 @@ KERNEL_FLOAT_CONSTANT_DEFINE_OP(*)
140154
KERNEL_FLOAT_CONSTANT_DEFINE_OP(/)
141155
KERNEL_FLOAT_CONSTANT_DEFINE_OP(%)
142156

157+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(==)
158+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(!=)
159+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(<=)
160+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(>=)
161+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(<)
162+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(>)
163+
143164
} // namespace kernel_float
144165

145166
#endif

single_include/kernel_float.h

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2026-04-09 10:28:01.035452
20-
// git hash: d4ea7202dd88aa23b79653ba45ffca3162e213bc
19+
// date: 2026-04-10 17:02:33.335438
20+
// git hash: 8f363a2146aff48ac4afc71f2283d91a6f1f65dd
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -2225,6 +2225,20 @@ struct cast<constant<T>, R, m> {
22252225
return cast<T, R, m> {}(input);
22262226
}
22272227
};
2228+
2229+
template<typename T>
2230+
struct cast<constant<T>, float> {
2231+
KERNEL_FLOAT_INLINE float operator()(const T& input) noexcept {
2232+
return cast<T, float> {}(input);
2233+
}
2234+
};
2235+
2236+
template<typename T, RoundingMode m>
2237+
struct cast<constant<T>, float, m> {
2238+
KERNEL_FLOAT_INLINE float operator()(const T& input) noexcept {
2239+
return cast<T, float, m> {}(input);
2240+
}
2241+
};
22282242
} // namespace ops
22292243

22302244
#define KERNEL_FLOAT_CONSTANT_DEFINE_OP(OP) \
@@ -2267,6 +2281,13 @@ KERNEL_FLOAT_CONSTANT_DEFINE_OP(*)
22672281
KERNEL_FLOAT_CONSTANT_DEFINE_OP(/)
22682282
KERNEL_FLOAT_CONSTANT_DEFINE_OP(%)
22692283

2284+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(==)
2285+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(!=)
2286+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(<=)
2287+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(>=)
2288+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(<)
2289+
KERNEL_FLOAT_CONSTANT_DEFINE_OP(>)
2290+
22702291
} // namespace kernel_float
22712292

22722293
#endif

tests/common.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,7 @@ struct device_runner {
351351
template<typename T, size_t N>
352352
void run() {
353353
if (cudaSetDevice(0) != cudaSuccess) {
354-
FAIL("failed to initialize CUDA device, does this machine have a GPU?");
354+
FAIL("failed to initialize CUDA device, run with '~[GPU]' to skip GPU tests");
355355
}
356356

357357
for (int seed = 0; seed < 5; seed++) {

tests/constant.cu

Lines changed: 27 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
#include "common.h"
22

3-
#define ASSERT_TYPE(A, B) ASSERT(std::is_same<decltype(A), B>::value);
4-
5-
struct constant_tests {
3+
struct constant_ops_tests {
64
template<typename T>
75
__host__ __device__ void operator()(generator<T> gen) {
86
T value = gen.next();
@@ -33,8 +31,32 @@ struct constant_tests {
3331
// ASSERT_EQ(value % kf::make_constant(5.0), value % T(5));
3432
// ASSERT_EQ(kf::make_constant(5.0) % vector, T(5) % vector);
3533
// ASSERT_EQ(vector % kf::make_constant(5.0), vector % T(5));
34+
35+
ASSERT_EQ(kf::cast<double>(kf::make_constant(T(5.0))), kf::make_vec(5.0));
36+
ASSERT_EQ(kf::cast<float>(kf::make_constant(T(5.0))), kf::make_vec(5.0f));
37+
ASSERT_EQ(kf::cast<int>(kf::make_constant(T(5.0))), kf::make_vec(5));
38+
39+
ASSERT_EQ(kf::cast<T>(kf::make_constant(double(5.0))), kf::make_vec(T(5.0)));
40+
ASSERT_EQ(kf::cast<T>(kf::make_constant(float(5.0))), kf::make_vec(T(5.0)));
41+
ASSERT_EQ(kf::cast<T>(kf::make_constant(int(5.0))), kf::make_vec(T(5.0)));
42+
}
43+
};
44+
45+
REGISTER_TEST_CASE("constant ops tests", constant_ops_tests, int, float, double)
46+
REGISTER_TEST_CASE_GPU("constant ops tests", constant_ops_tests, __half, __nv_bfloat16)
47+
48+
struct constant_eq_tests {
49+
template<typename T>
50+
__host__ __device__ void operator()(generator<T> gen) {
51+
ASSERT(kf::make_constant(T(5.0)) == double(5.0));
52+
ASSERT(kf::make_constant(T(5.0)) == float(5.0));
53+
ASSERT(kf::make_constant(T(5.0)) == int(5.0));
54+
55+
ASSERT(kf::make_constant(double(5.0)) == T(5.0));
56+
ASSERT(kf::make_constant(float(5.0)) == T(5.0));
57+
ASSERT(kf::make_constant(int(5.0)) == T(5.0));
3658
}
3759
};
3860

39-
REGISTER_TEST_CASE("constant tests", constant_tests, int, float, double)
40-
REGISTER_TEST_CASE_GPU("constant tests", constant_tests, __half, __nv_bfloat16)
61+
REGISTER_TEST_CASE("constant eq tests", constant_eq_tests, int, float, double)
62+
REGISTER_TEST_CASE_GPU("constant eq tests", constant_eq_tests, __half, __nv_bfloat16)

tests/unops.cu

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,13 @@ struct unops_tests {
88
kf::vec<T, N> b;
99

1010
b = -a;
11-
ASSERT(equals(b[I], T(-items[I])) && ...);
11+
ASSERT_EQ_ALL(b[I], T(-items[I]));
1212

1313
b = ~a;
14-
ASSERT(equals(b[I], T(~items[I])) && ...);
14+
ASSERT_EQ_ALL(b[I], T(~items[I]));
1515

1616
b = !a;
17-
ASSERT(equals(b[I], T(!items[I])) && ...);
17+
ASSERT_EQ_ALL(b[I], T(!items[I]));
1818
}
1919
};
2020

@@ -28,58 +28,59 @@ struct unops_float_tests {
2828
kf::vec<T, N> b;
2929

3030
b = -a;
31-
ASSERT(equals(b[I], T(-items[I])) && ...);
31+
ASSERT_EQ_ALL(b[I], T(-items[I]));
3232

3333
b = !a;
34-
ASSERT(equals(b[I], items[I] == 0.0 ? T(1.0) : T(0.0)) && ...);
34+
ASSERT_EQ_ALL(b[I], items[I] == 0.0 ? T(1.0) : T(0.0));
3535

3636
// Ideally, we would test all unary operators, but that would be a lot of work and not that useful since
3737
// all operators are generated by the same macro. Instead, we only check a few of them
3838
if constexpr (is_one_of<T, __half, __nv_bfloat16>) {
3939
// operations on 16-bit numbers are only supported in CC >= 8
4040
#if KERNEL_FLOAT_CUDA_ARCH >= 800
4141
b = sqrt(a);
42-
ASSERT(equals(b[I], hsqrt(T(items[I]))) && ...);
42+
ASSERT_EQ_ALL(b[I], hsqrt(T(items[I])));
4343

4444
b = sin(a);
45-
ASSERT(equals(b[I], hsin(T(items[I]))) && ...);
45+
ASSERT_EQ_ALL(b[I], hsin(T(items[I])));
4646

4747
b = cos(a);
48-
ASSERT(equals(b[I], hcos(T(items[I]))) && ...);
48+
ASSERT_EQ_ALL(b[I], hcos(T(items[I])));
4949

5050
b = log(a);
51-
ASSERT(equals(b[I], hlog(T(items[I]))) && ...);
51+
ASSERT_EQ_ALL(b[I], hlog(T(items[I])));
5252

5353
b = exp(a);
54-
ASSERT(equals(b[I], hexp(T(items[I]))) && ...);
54+
ASSERT_EQ_ALL(b[I], hexp(T(items[I])));
5555

5656
b = rcp(a);
57-
ASSERT(equals(b[I], hrcp(T(items[I]))) && ...);
57+
ASSERT_EQ_ALL(b[I], hrcp(T(items[I])));
5858

5959
b = rsqrt(a);
60-
ASSERT(equals(b[I], hrsqrt(T(items[I]))) && ...);
60+
ASSERT_EQ_ALL(b[I], hrsqrt(T(items[I])));
6161
#endif
6262
} else {
6363
b = sqrt(a);
64-
ASSERT(equals(b[I], sqrt(T(items[I]))) && ...);
64+
ASSERT_EQ_ALL(b[I], sqrt(T(items[I])));
6565

6666
b = sin(a);
67-
ASSERT(equals(b[I], sin(T(items[I]))) && ...);
67+
ASSERT_EQ_ALL(b[I], sin(T(items[I])));
6868

6969
b = cos(a);
70-
ASSERT(equals(b[I], cos(T(items[I]))) && ...);
70+
ASSERT_EQ_ALL(b[I], cos(T(items[I])));
7171

7272
b = log(a);
73-
ASSERT(equals(b[I], log(T(items[I]))) && ...);
73+
ASSERT_EQ_ALL(b[I], log(T(items[I])));
7474

7575
b = exp(a);
76-
ASSERT(equals(b[I], exp(T(items[I]))) && ...);
76+
ASSERT_EQ_ALL(b[I], exp(T(items[I])));
7777

7878
b = rcp(a);
79-
ASSERT(equals(b[I], rcp(T(items[I]))) && ...);
79+
ASSERT_EQ_ALL(b[I], T(1.0 / items[I]));
8080

81+
// seems that rsqrt does not match bitwise on GPU
8182
b = rsqrt(a);
82-
ASSERT(equals(b[I], rsqrt(T(items[I]))) && ...);
83+
ASSERT_APPROX_ALL(b[I], rsqrt(T(items[I])));
8384
}
8485
}
8586
};

0 commit comments

Comments
 (0)