Skip to content

Commit 390b7af

Browse files
ssnl authored and soumith committed
Fix CUDA Multinomial checks (#4009)
1 parent 43dd631 commit 390b7af

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

aten/src/TH/generic/THTensorRandom.c

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -265,7 +265,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
265265
THArgCheckWithCleanup(n_sample > 0,
266266
THCleanup(if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
267267
2,
268-
"cannot sample n_sample < 0 samples");
268+
"cannot sample n_sample <= 0 samples");
269269

270270
if (!with_replacement)
271271
{
@@ -285,12 +285,18 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
285285
{
286286
/* Get normalized cumulative distribution from prob distribution */
287287
double sum = 0;
288+
double val;
288289
for (j=0; j<n_categories; j++)
289290
{
290-
sum += THStorage_(get)( \
291+
val = THStorage_(get)( \
291292
prob_dist->storage, \
292293
prob_dist->storageOffset+i*prob_dist->stride[0]+j*prob_dist->stride[1] \
293294
);
295+
THArgCheckWithCleanup((val >= 0),
296+
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
297+
2,
298+
"invalid multinomial distribution (encountering probability entry < 0)");
299+
sum += val;
294300
THDoubleStorage_set(
295301
cum_dist->storage, \
296302
cum_dist->storageOffset+j*cum_dist->stride[0], \

aten/src/THC/THCTensorRandom.cuh

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ multinomialAliasDrawKernel(int size, int64_t *output, int64_t *J, T *q, int64_t
4646
T bern_uniform = bernoulli[idx];
4747
int _mask = (int) THCNumerics<T>::lt(bern_uniform, q[rand_ind]);
4848
output[idx] = J[rand_ind]*(1 -_mask) + (rand_ind+1L) * _mask;
49-
}
49+
}
5050
}
5151

5252
template <typename T>
@@ -94,15 +94,19 @@ template <typename T>
9494
__global__ void renormRowsL1(T* dist, long rows, long cols) {
9595
extern __shared__ unsigned char my_smem[];
9696
T *smem = reinterpret_cast<T *>(my_smem);
97-
97+
T zero = ScalarConvert<int, T>::to(0);
98+
T val;
9899
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
99100
T sum = ScalarConvert<int, T>::to(0);
100101
for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
101-
sum = THCNumerics<T>::add(sum, dist[row * cols + col]);
102+
val = dist[row * cols + col];
103+
assert(THCNumerics<T>::ge(val, zero));
104+
sum = THCNumerics<T>::add(sum, val);
102105
}
103106

104-
sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), ScalarConvert<int, T>::to(0));
107+
sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), zero);
105108
if (threadIdx.x == 0) {
109+
assert(THCNumerics<T>::gt(sum, zero));
106110
smem[0] = sum;
107111
}
108112
__syncthreads();
@@ -169,10 +173,11 @@ sampleMultinomialOnce(int64_t* dest,
169173
// Each block handles one distribution
170174
// First pass, find the total sum of the distribution
171175
AccT sum = accZero;
176+
T val;
172177
for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
173-
sum = THCNumerics<AccT>::add(
174-
sum,
175-
ScalarConvert<T, AccT>::to(dist[curDist * categories + cat]));
178+
val = dist[curDist * categories + cat];
179+
assert(THCNumerics<T>::ge(val, zero));
180+
sum = THCNumerics<AccT>::add(sum, ScalarConvert<T, AccT>::to(val));
176181
}
177182

178183
// threadIdx.x == 0 has the sum value from this
@@ -182,6 +187,7 @@ sampleMultinomialOnce(int64_t* dest,
182187
if (threadIdx.x == 0) {
183188
// Make sure the sum of our distribution didn't overflow
184189
assert(!isinf(sum));
190+
assert(THCNumerics<AccT>::gt(sum, accZero));
185191

186192
asmem[0] = sum;
187193
smem[0] = sampled[curDist];

0 commit comments

Comments
 (0)