Skip to content

Commit 927d608

Browse files
committed
Widening fix
1 parent b2d6102 commit 927d608

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

src/mapreduce.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,19 @@
66

77
# Widen sub-word types to avoid shared memory corruption on Intel GPUs.
88
# Writing 1/2-byte values to local memory can clobber adjacent bytes.
9-
# Generic rule: any type narrower than 4 bytes is stored as Int32; all other types are
# stored unchanged.
# NOTE(review): this also selects Int32 for non-integer sub-word types (e.g. Float16);
# confirm every caller can convert such values with `%`.
@inline _widen_type(::Type{T}) where T = sizeof(T) < 4 ? Int32 : T
9+
# Only primitive sub-word types with a lossless round-trip are widened:
# integers/booleans are stored as Int32 (converted with `%`), Float16 as Float32.
@inline _widen_type(::Type{Bool}) = Int32
@inline _widen_type(::Type{Int8}) = Int32
@inline _widen_type(::Type{UInt8}) = Int32
@inline _widen_type(::Type{Int16}) = Int32
@inline _widen_type(::Type{UInt16}) = Int32
# Float16 is 2 bytes wide as well. `%` is not defined for floats, so it is stored as
# Float32, which represents every Float16 value exactly.
@inline _widen_type(::Type{Float16}) = Float32
# Every other type is stored unchanged.
@inline _widen_type(::Type{T}) where T = T

# Dispatch-based conversions between a value type and its storage type.
# `_to_wide(val, W)` converts `val` to storage type `W`; `_from_wide(val, T)` converts a
# stored value back to `T`. A value that already has the requested type is returned
# untouched, so `%` is only reached when the two types differ.
@inline _to_wide(val, ::Type{W}) where W = val % W
@inline _to_wide(val::T, ::Type{T}) where T = val
@inline _to_wide(val::Float16, ::Type{Float32}) = Float32(val)
@inline _from_wide(val, ::Type{T}) where T = val % T
@inline _from_wide(val::T, ::Type{T}) where T = val
@inline _from_wide(val::Float32, ::Type{Float16}) = Float16(val)
1022

1123
# Reduce a value across a group, using local memory for communication
1224
@inline function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems}
@@ -16,7 +28,7 @@
1628
# use a wider type for shared memory to avoid sub-word corruption
1729
W = _widen_type(T)
1830
shared = oneLocalArray(W, (maxitems,))
19-
@inbounds shared[item] = val % W
31+
@inbounds shared[item] = _to_wide(val, W)
2032

2133
# perform a reduction
2234
d = 1
@@ -25,18 +37,18 @@
2537
index = 2 * d * (item-1) + 1
2638
@inbounds if index <= items
2739
other_val = if index + d <= items
28-
shared[index+d] % T
40+
_from_wide(shared[index+d], T)
2941
else
3042
neutral
3143
end
32-
shared[index] = op(shared[index] % T, other_val) % W
44+
shared[index] = _to_wide(op(_from_wide(shared[index], T), other_val), W)
3345
end
3446
d *= 2
3547
end
3648

3749
# load the final value on the first item
3850
if item == 1
39-
val = @inbounds shared[item] % T
51+
val = @inbounds _from_wide(shared[item], T)
4052
end
4153

4254
return val

0 commit comments

Comments
 (0)