Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions aten/src/TH/generic/THTensorRandom.c
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
THArgCheckWithCleanup(n_sample > 0,
THCleanup(if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
2,
"cannot sample n_sample < 0 samples");
"cannot sample n_sample <= 0 samples");

if (!with_replacement)
{
Expand All @@ -273,12 +273,18 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
{
/* Get normalized cumulative distribution from prob distribution */
double sum = 0;
double val;
for (j=0; j<n_categories; j++)
{
sum += THStorage_(get)( \
val = THStorage_(get)( \
prob_dist->storage, \
prob_dist->storageOffset+i*prob_dist->stride[0]+j*prob_dist->stride[1] \
);
THArgCheckWithCleanup((val >= 0),
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
2,
"invalid multinomial distribution (encountering probability entry < 0)");
sum += val;
THDoubleStorage_set(
cum_dist->storage, \
cum_dist->storageOffset+j*cum_dist->stride[0], \
Expand Down
20 changes: 13 additions & 7 deletions aten/src/THC/THCTensorRandom.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ multinomialAliasDrawKernel(int size, int64_t *output, int64_t *J, T *q, int64_t
T bern_uniform = bernoulli[idx];
int _mask = (int) THCNumerics<T>::lt(bern_uniform, q[rand_ind]);
output[idx] = J[rand_ind]*(1 -_mask) + (rand_ind+1L) * _mask;
}
}
}

template <typename T>
Expand Down Expand Up @@ -94,15 +94,19 @@ template <typename T>
__global__ void renormRowsL1(T* dist, long rows, long cols) {
extern __shared__ unsigned char my_smem[];
T *smem = reinterpret_cast<T *>(my_smem);

T zero = ScalarConvert<int, T>::to(0);
T val;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
T sum = ScalarConvert<int, T>::to(0);
for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
sum = THCNumerics<T>::add(sum, dist[row * cols + col]);
val = dist[row * cols + col];
assert(THCNumerics<T>::ge(val, zero));
sum = THCNumerics<T>::add(sum, val);
}

sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), ScalarConvert<int, T>::to(0));
sum = reduceBlock(smem, blockDim.x, sum, ReduceAdd<T, T>(), zero);
if (threadIdx.x == 0) {
assert(THCNumerics<T>::gt(sum, zero));

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

smem[0] = sum;
}
__syncthreads();
Expand Down Expand Up @@ -169,10 +173,11 @@ sampleMultinomialOnce(int64_t* dest,
// Each block handles one distribution
// First pass, find the total sum of the distribution
AccT sum = accZero;
T val;
for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
sum = THCNumerics<AccT>::add(
sum,
ScalarConvert<T, AccT>::to(dist[curDist * categories + cat]));
val = dist[curDist * categories + cat];
assert(THCNumerics<T>::ge(val, zero));
sum = THCNumerics<AccT>::add(sum, ScalarConvert<T, AccT>::to(val));
}

// threadIdx.x == 0 has the sum value from this
Expand All @@ -182,6 +187,7 @@ sampleMultinomialOnce(int64_t* dest,
if (threadIdx.x == 0) {
// Make sure the sum of our distribution didn't overflow
assert(!isinf(sum));
assert(THCNumerics<AccT>::gt(sum, accZero));

This comment was marked as off-topic.


asmem[0] = sum;
smem[0] = sampled[curDist];
Expand Down