Skip to content

Commit 4c4816d

Browse files
r-barnes authored and facebook-github-bot committed
Fix CUDA kernel index data type in deeplearning/projects/fairseq-py/fairseq/modules/cuda_utils.cu +10
Summary: CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables). Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) itself is inconsistent and incorrect in its use of data types in programming examples. The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding ~2B items. While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them. Reviewed By: dtolnay Differential Revision: D71355350 fbshipit-source-id: a23a7b3ab08dd958db91bd55fe5cad47dd9741f0
1 parent 158f467 commit 4c4816d

4 files changed

Lines changed: 32 additions & 32 deletions

File tree

flashlight/lib/sequence/criterion/cuda/CriterionUtils.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using namespace fl::lib::seq;
2020
*/
2121
__global__ void
2222
batchTargetSizeKernel(int L, int maxSize, const int* _target, int* targetSize) {
23-
int b = blockIdx.x;
23+
auto b = blockIdx.x;
2424
auto target = _target + b * L;
2525

2626
__shared__ int idx;
@@ -31,7 +31,7 @@ batchTargetSizeKernel(int L, int maxSize, const int* _target, int* targetSize) {
3131

3232
__syncthreads();
3333

34-
for (int i = L - 1 - threadIdx.x; i >= 0; i -= blockDim.x) {
34+
for (auto i = L - 1 - threadIdx.x; i >= 0; i -= blockDim.x) {
3535
if (target[i] >= 0) {
3636
atomicMax(&idx, i + 1);
3737
break;
@@ -57,7 +57,7 @@ __global__ void computeScaleKernel(
5757
CriterionScaleMode scaleMode,
5858
const int* targetSize,
5959
Float* scale) {
60-
for (int b = threadIdx.x; b < B; b += blockDim.x) {
60+
for (auto b = threadIdx.x; b < B; b += blockDim.x) {
6161
switch (scaleMode) {
6262
case CriterionScaleMode::NONE:
6363
scale[b] = 1.0;

flashlight/lib/sequence/criterion/cuda/ForceAlignmentCriterion.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,15 @@ __global__ void forwardKernel(
5757
const Float* trans,
5858
Float* _loss,
5959
WorkspacePtrs<Float> ws) {
60-
int b = blockIdx.x;
60+
auto b = blockIdx.x;
6161
auto* alpha = &ws.alpha[b * T * _L];
6262
auto* input = &_input[b * T * N];
6363
auto* target = &_target[b * _L];
6464
auto* transBuf1 = &ws.transBuf1[b * _L];
6565
auto* transBuf2 = &ws.transBuf2[b * _L];
6666
int L = targetSize[b];
6767

68-
for (int i = threadIdx.x; i < L; i += blockDim.x) {
68+
for (auto i = threadIdx.x; i < L; i += blockDim.x) {
6969
alpha[i] = i == 0 ? input[target[0]] : 0;
7070
transBuf1[i] = trans[target[i] * N + target[i]];
7171
transBuf2[i] = i > 0 ? trans[target[i] * N + target[i - 1]] : 0;
@@ -92,7 +92,7 @@ __global__ void forwardKernel(
9292
}
9393
}
9494

95-
for (int i = low + threadIdx.x; i < high; i += blockDim.x) {
95+
for (auto i = low + threadIdx.x; i < high; i += blockDim.x) {
9696
double s1 = alphaPrev[i] + transBuf1[i];
9797
double s2 = alphaPrev[i - 1] + transBuf2[i];
9898
// lse = logSumExp(s1, s2)
@@ -124,7 +124,7 @@ __global__ void backwardKernel(
124124
Float* _inputGrad,
125125
Float* transGrad,
126126
WorkspacePtrs<Float> ws) {
127-
int b = blockIdx.x;
127+
auto b = blockIdx.x;
128128
auto* alpha = &ws.alpha[b * T * _L];
129129
auto* alphaGrad = &ws.alphaGrad[b * T * _L];
130130
auto* inputGrad = &_inputGrad[b * T * N];
@@ -154,7 +154,7 @@ __global__ void backwardKernel(
154154

155155
__syncthreads();
156156

157-
for (int i = low1 + threadIdx.x; i < high1; i += blockDim.x) {
157+
for (auto i = low1 + threadIdx.x; i < high1; i += blockDim.x) {
158158
atomicAdd(&inputCurGrad[target[i]], alphaCurGrad[i]);
159159
}
160160

@@ -170,7 +170,7 @@ __global__ void backwardKernel(
170170
}
171171
}
172172

173-
for (int i = low + threadIdx.x; i < high; i += blockDim.x) {
173+
for (auto i = low + threadIdx.x; i < high; i += blockDim.x) {
174174
double s1 = alphaPrev[i] + transBuf1[i];
175175
double s2 = alphaPrev[i - 1] + transBuf2[i];
176176
// d1, d2 = dLogSumExp(s1, s2)
@@ -198,7 +198,7 @@ __global__ void backwardKernel(
198198
gradScale = grad[b] * ws.scale[b];
199199
}
200200

201-
for (int i = threadIdx.x; i < L; i += blockDim.x) {
201+
for (auto i = threadIdx.x; i < L; i += blockDim.x) {
202202
atomicAdd(&transBatchGrad[target[i] * N + target[i]], transBufGrad1[i]);
203203
if (i > 0) {
204204
atomicAdd(
@@ -208,11 +208,11 @@ __global__ void backwardKernel(
208208

209209
__syncthreads();
210210

211-
for (int i = threadIdx.x; i < T * N; i += blockDim.x) {
211+
for (auto i = threadIdx.x; i < T * N; i += blockDim.x) {
212212
inputGrad[i] *= gradScale;
213213
}
214214

215-
for (int i = threadIdx.x; i < N * N; i += blockDim.x) {
215+
for (auto i = threadIdx.x; i < N * N; i += blockDim.x) {
216216
atomicAdd(&transGrad[i], gradScale * transBatchGrad[i]);
217217
}
218218
}
@@ -228,19 +228,19 @@ __global__ void viterbiPathKernel(
228228
const Float* trans,
229229
int* bestPaths,
230230
WorkspacePtrs<Float> ws) {
231-
int b = blockIdx.x;
231+
auto b = blockIdx.x;
232232
auto* alpha = &ws.alpha[b * T * _L];
233233
auto* input = &_input[b * T * N];
234234
auto* target = &_target[b * _L];
235235
auto* transBuf1 = &ws.transBuf1[b * _L];
236236
auto* transBuf2 = &ws.transBuf2[b * _L];
237237
int L = targetSize[b];
238238

239-
for (int i = threadIdx.x; i < L * T; i += blockDim.x) {
239+
for (auto i = threadIdx.x; i < L * T; i += blockDim.x) {
240240
alpha[i] = i == 0 ? input[target[0]] : -CUDART_INF_F;
241241
}
242242

243-
for (int i = threadIdx.x; i < L; i += blockDim.x) {
243+
for (auto i = threadIdx.x; i < L; i += blockDim.x) {
244244
transBuf1[i] = trans[target[i] * N + target[i]];
245245
transBuf2[i] = i > 0 ? trans[target[i] * N + target[i - 1]] : 0;
246246
}
@@ -270,7 +270,7 @@ __global__ void viterbiPathKernel(
270270
}
271271
}
272272

273-
for (int i = low + threadIdx.x; i < high; i += blockDim.x) {
273+
for (auto i = low + threadIdx.x; i < high; i += blockDim.x) {
274274
double s1 = alphaPrev[i] + transBuf1[i];
275275
double s2 = alphaPrev[i - 1] + transBuf2[i];
276276
alphaCur[i] = inputCur[target[i]] + max(s1, s2);

flashlight/lib/sequence/criterion/cuda/FullConnectionCriterion.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ struct WorkspacePtrs {
4646
template <class Float>
4747
__global__ void
4848
forwardInitial(int T, int N, const Float* input, WorkspacePtrs<Float> ws) {
49-
int b = blockIdx.x;
50-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
49+
auto b = blockIdx.x;
50+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
5151
int k = b * T * N + n;
5252
ws.alpha[k] = input[k];
5353
}
@@ -84,7 +84,7 @@ __global__ void forwardStep(
8484
__shared__ double maxValue;
8585

8686
double threadMax = -INFINITY;
87-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
87+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
8888
double val = transBuf[n] = alphaPrev[n] + (Final ? 0 : trans[m * N + n]);
8989
threadMax = val > threadMax ? val : threadMax;
9090
}
@@ -97,7 +97,7 @@ __global__ void forwardStep(
9797
__syncthreads();
9898

9999
double threadSum = 0;
100-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
100+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
101101
threadSum += exp(transBuf[n] - maxValue);
102102
}
103103

@@ -142,7 +142,7 @@ __global__ void backwardStep1(
142142
__shared__ double sumValue;
143143

144144
double threadMax = -INFINITY;
145-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
145+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
146146
double val = transBuf[n] = alphaPrev[n] + (Initial ? 0 : trans[m * N + n]);
147147
threadMax = val > threadMax ? val : threadMax;
148148
}
@@ -153,7 +153,7 @@ __global__ void backwardStep1(
153153
}
154154

155155
double threadSum = 0;
156-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
156+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
157157
transBuf[n] = exp(transBuf[n] - maxValue);
158158
threadSum += transBuf[n];
159159
}
@@ -165,7 +165,7 @@ __global__ void backwardStep1(
165165

166166
__syncthreads();
167167

168-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
168+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
169169
if (Initial) {
170170
alphaPrevGrad[n] = transBuf[n] / sumValue;
171171
} else {
@@ -181,16 +181,16 @@ __global__ void backwardStep1(
181181
*/
182182
template <class Float>
183183
__global__ void backwardStep2(int T, int N, int t, WorkspacePtrs<Float> ws) {
184-
int b = blockIdx.x / N;
185-
int m = blockIdx.x % N;
184+
auto b = blockIdx.x / N;
185+
auto m = blockIdx.x % N;
186186

187187
auto* alphaPrevGrad = &ws.alphaGrad[b * T * N + (t - 1) * N];
188188

189189
using BlockReduce = cub::BlockReduce<double, kBlockSize>;
190190
__shared__ typename BlockReduce::TempStorage tempStorage;
191191

192192
double threadSum = 0;
193-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
193+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
194194
threadSum += ws.transBuf[b * N * N + n * N + m];
195195
}
196196

@@ -212,7 +212,7 @@ __global__ void backwardFinal(
212212
Float* _inputGrad,
213213
Float* transGrad,
214214
WorkspacePtrs<Float> ws) {
215-
int b = blockIdx.x;
215+
auto b = blockIdx.x;
216216

217217
auto* alphaGrad = &ws.alphaGrad[b * T * N];
218218
auto* inputGrad = &_inputGrad[b * T * N];
@@ -226,11 +226,11 @@ __global__ void backwardFinal(
226226

227227
__syncthreads();
228228

229-
for (int i = threadIdx.x; i < T * N; i += blockDim.x) {
229+
for (auto i = threadIdx.x; i < T * N; i += blockDim.x) {
230230
inputGrad[i] = gradScale * alphaGrad[i];
231231
}
232232

233-
for (int i = threadIdx.x; i < N * N; i += blockDim.x) {
233+
for (auto i = threadIdx.x; i < N * N; i += blockDim.x) {
234234
atomicAdd(&transGrad[i], gradScale * transBatchGrad[i]);
235235
}
236236
}

flashlight/lib/sequence/criterion/cuda/ViterbiPath.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ struct WorkspacePtrs {
3838
template <class Float>
3939
__global__ void
4040
computeInitial(int T, int N, const Float* input, WorkspacePtrs<Float> ws) {
41-
int b = blockIdx.x;
42-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
41+
auto b = blockIdx.x;
42+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
4343
ws.alpha[b * 2 * N + n] = input[b * T * N + n];
4444
}
4545
}
@@ -76,7 +76,7 @@ __global__ void computeStep(
7676

7777
cub::KeyValuePair<int, Float> threadMax;
7878
threadMax.value = -INFINITY;
79-
for (int n = threadIdx.x; n < N; n += blockDim.x) {
79+
for (auto n = threadIdx.x; n < N; n += blockDim.x) {
8080
Float val = alphaPrev[n] + (Final ? 0 : trans[m * N + n]);
8181
if (val > threadMax.value) {
8282
threadMax.key = n;

0 commit comments

Comments
 (0)