Skip to content

Commit 98879d5

Browse files
ssnl authored and soumith committed
Fix CUDA Multinomial checks (#4009)
1 parent e98af60 commit 98879d5

File tree

2 files changed: 21 additions and 9 deletions

torch/lib/TH/generic/THTensorRandom.c

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -247,7 +247,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
247247
THArgCheckWithCleanup(n_sample > 0,
248248
THCleanup(if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
249249
2,
250-
"cannot sample n_sample < 0 samples");
250+
"cannot sample n_sample <= 0 samples");
251251

252252
if (!with_replacement)
253253
{
@@ -267,12 +267,18 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
267267
{
268268
/* Get normalized cumulative distribution from prob distribution */
269269
double sum = 0;
270+
double val;
270271
for (j=0; j<n_categories; j++)
271272
{
272-
sum += THStorage_(get)( \
273+
val = THStorage_(get)( \
273274
prob_dist->storage, \
274275
prob_dist->storageOffset+i*prob_dist->stride[0]+j*prob_dist->stride[1] \
275276
);
277+
THArgCheckWithCleanup((val >= 0),
278+
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
279+
2,
280+
"invalid multinomial distribution (encountering probability entry < 0)");
281+
sum += val;
276282
THDoubleStorage_set(
277283
cum_dist->storage, \
278284
cum_dist->storageOffset+j*cum_dist->stride[0], \

torch/lib/THC/THCTensorRandom.cuh

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ multinomialAliasDrawKernel(int size, int64_t *output, int64_t *J, T *q, int64_t
4646
T bern_uniform = bernoulli[idx];
4747
int _mask = (int) THCNumerics<T>::lt(bern_uniform, q[rand_ind]);
4848
output[idx] = J[rand_ind]*(1 -_mask) + (rand_ind+1L) * _mask;
49-
}
49+
}
5050
}
5151

5252
template <typename T>
@@ -94,15 +94,19 @@ template <typename T>
9494
__global__ void renormRowsL1(T* dist, long rows, long cols) {
9595
extern __shared__ unsigned char my_smem[];
9696
T *smem = reinterpret_cast<T *>(my_smem);
97-
97+
T zero = ScalarConvert<int, T>::to(0);
98+
T val;
9899
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
99100
T sum = ScalarConvert<int, T>::to(0);
100101
for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
101-
sum = THCNumerics<T>::add(sum, dist[row * cols + col]);
102+
val = dist[row * cols + col];
103+
assert(THCNumerics<T>::ge(val, zero));
104+
sum = THCNumerics<T>::add(sum, val);
102105
}
103106

104-
sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), ScalarConvert<int, T>::to(0));
107+
sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), zero);
105108
if (threadIdx.x == 0) {
109+
assert(THCNumerics<T>::gt(sum, zero));
106110
smem[0] = sum;
107111
}
108112
__syncthreads();
@@ -169,10 +173,11 @@ sampleMultinomialOnce(int64_t* dest,
169173
// Each block handles one distribution
170174
// First pass, find the total sum of the distribution
171175
AccT sum = accZero;
176+
T val;
172177
for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
173-
sum = THCNumerics<AccT>::add(
174-
sum,
175-
ScalarConvert<T, AccT>::to(dist[curDist * categories + cat]));
178+
val = dist[curDist * categories + cat];
179+
assert(THCNumerics<T>::ge(val, zero));
180+
sum = THCNumerics<AccT>::add(sum, ScalarConvert<T, AccT>::to(val));
176181
}
177182

178183
// threadIdx.x == 0 has the sum value from this
@@ -182,6 +187,7 @@ sampleMultinomialOnce(int64_t* dest,
182187
if (threadIdx.x == 0) {
183188
// Make sure the sum of our distribution didn't overflow
184189
assert(!isinf(sum));
190+
assert(THCNumerics<AccT>::gt(sum, accZero));
185191

186192
asmem[0] = sum;
187193
smem[0] = sampled[curDist];

0 commit comments

Comments
 (0)