30 changes: 15 additions & 15 deletions aten/src/TH/generic/THStorage.c
@@ -155,28 +155,28 @@ void THStorage_(resize)(THStorage *storage, ptrdiff_t size)
       real *old_data = storage->data;
       ptrdiff_t old_size = storage->size;
       if (size == 0) {
-	storage->data = NULL;
+        storage->data = NULL;
       } else {
-	storage->data = storage->allocator->malloc(
-	  storage->allocatorContext,
-	  sizeof(real)*size);
+        storage->data = storage->allocator->malloc(
+            storage->allocatorContext,
+            sizeof(real)*size);
       }
       storage->size = size;
       if (old_data != NULL) {
-	ptrdiff_t copy_size = old_size;
-	if (storage->size < copy_size) {
-	  copy_size = storage->size;
-	}
-	if (copy_size > 0) {
-	  memcpy(storage->data, old_data, sizeof(real)*copy_size);
-	}
-	storage->allocator->free(storage->allocatorContext, old_data);
+        ptrdiff_t copy_size = old_size;
+        if (storage->size < copy_size) {
+          copy_size = storage->size;
+        }
+        if (copy_size > 0) {
+          memcpy(storage->data, old_data, sizeof(real)*copy_size);
+        }
+        storage->allocator->free(storage->allocatorContext, old_data);
       }
     } else {
       storage->data = storage->allocator->realloc(
-	storage->allocatorContext,
-	storage->data,
-	sizeof(real)*size);
+          storage->allocatorContext,
+          storage->data,
+          sizeof(real)*size);
       storage->size = size;
     }
   } else {
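This hunk appears to be an indentation-only cleanup, but the contract of the code it touches is worth spelling out: when the allocator has no realloc, resize allocates a fresh buffer and keeps only the leading min(old_size, new_size) elements. A minimal Python sketch of that copy path (illustrative names, not TH API; C leaves the tail uninitialized, the zeros below are just for the sketch):

```python
def storage_resize(old_data, new_size):
    """Sketch of the THStorage_(resize) malloc path: allocate a new
    buffer, keep the leading min(old, new) elements, drop the rest."""
    new_data = [0.0] * new_size                   # allocator->malloc
    copy_size = min(len(old_data), new_size)      # clamp to the smaller size
    new_data[:copy_size] = old_data[:copy_size]   # memcpy
    return new_data                               # old buffer is freed in C

assert storage_resize([1.0, 2.0, 3.0], 2) == [1.0, 2.0]
assert storage_resize([1.0, 2.0], 4) == [1.0, 2.0, 0.0, 0.0]
```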
3 changes: 3 additions & 0 deletions aten/src/TH/generic/THTensor.h
@@ -76,6 +76,9 @@ TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size);
 TH_API void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size);
 TH_API void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count);
 
+// resize* methods simply resize the storage, so they may not retain the current data at current indices.
+// This is especially likely when the tensor is not contiguous. In general, if you still need the current
+// values, do not use resize* unless you are doing size and stride tricks.
 TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
 TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);
 TH_API void THTensor_(resizeNd)(THTensor *tensor, int nDimension, int64_t *size, int64_t *stride);
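The caveat in the new comment is easiest to see from Python, where resize_ is the user-facing counterpart. A sketch (exact behavior of resize_ on non-contiguous views varies across versions, so treat this as illustrative rather than a guaranteed repro):

```python
import torch

base = torch.arange(6.)
t = base.view(2, 3).t()        # non-contiguous 3x2 view, strides (1, 3)
before = t.clone()             # snapshot of the logical values

t.resize_(3, 2)                # same element count, but resize only
                               # reinterprets storage with fresh strides
print(torch.equal(t, before))  # typically False: values changed indices
```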
14 changes: 7 additions & 7 deletions aten/src/TH/generic/THTensorRandom.cpp
@@ -284,23 +284,23 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
 
   if (start_dim == 1)
   {
-    THTensor_(resize2d)(prob_dist, 1, THTensor_(size)(prob_dist, 0));
+    THTensor_(unsqueeze1d)(prob_dist, prob_dist, 0);
   }
 
   n_dist = THTensor_(size)(prob_dist, 0);
   n_categories = THTensor_(size)(prob_dist, 1);
 
   THArgCheckWithCleanup(n_sample > 0,
-                        THCleanup(if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
+                        THCleanup(if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
                         2,
                         "cannot sample n_sample <= 0 samples");
 
   if (!with_replacement)
   {
     THArgCheckWithCleanup((!with_replacement) && (n_sample <= n_categories),
-                          THCleanup(if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
+                          THCleanup(if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
                           2,
-                          "cannot sample n_sample > prob_dist:size(1) samples without replacement");
+                          "cannot sample n_sample > prob_dist.size(1) samples without replacement");
   }
 
   /* cumulative probability distribution vector */
@@ -321,7 +321,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
         prob_dist->storageOffset+i*prob_dist->stride[0]+j*prob_dist->stride[1] \
       );
       THArgCheckWithCleanup((val >= 0),
-                            THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
+                            THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
                             2,
                             "invalid multinomial distribution (encountering probability entry < 0)");
       sum += val;
@@ -332,7 +332,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
     );
   }
   THArgCheckWithCleanup((sum > 0),
-                        THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(resize1d)(prob_dist, n_categories);),
+                        THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
                         2,
                         "invalid multinomial distribution (sum of probabilities <= 0)");
   /* normalize cumulative probability distribution so that last val is 1
@@ -434,7 +434,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
   if (start_dim == 1)
   {
     THLongTensor_resize1d(self, n_sample);
-    THTensor_(resize1d)(prob_dist, n_categories);
+    THTensor_(squeeze1d)(prob_dist, prob_dist, 0);
   }
 }
 #endif
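The CPU change swaps the resize2d/resize1d round trip for unsqueeze1d/squeeze1d, which only add and remove a leading view dimension and never touch the data or strides of the caller's tensor. The same pattern at the Python level (illustrative):

```python
import torch

p = torch.tensor([0.2, 0.3, 0.5])   # 1-D distribution
p2 = p.unsqueeze(0)                 # (1, 3) view; storage is untouched
out = torch.multinomial(p2, 4, replacement=True)
out = out.squeeze(0)                # back to a 1-D result
assert p.shape == (3,)              # caller's tensor is unchanged
```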
11 changes: 7 additions & 4 deletions aten/src/THC/THCTensorRandom.cuh
@@ -156,7 +156,10 @@ sampleMultinomialOnce(int64_t* dest,
                       int64_t distributions,
                       int categories,
                       T* sampled,
-                      T* dist) {
+                      T* dist,
+                      int stride_dist,       // dist->stride[0]
+                      int stride_categories  // dist->stride[1]
+                      ) {
   extern __shared__ unsigned char my_smem[];
   __shared__ bool found;
 
@@ -175,7 +178,7 @@ sampleMultinomialOnce(int64_t* dest,
     AccT sum = accZero;
     T val;
     for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
-      val = dist[curDist * categories + cat];
+      val = dist[curDist * stride_dist + cat * stride_categories];
       assert(THCNumerics<T>::ge(val, zero));
       sum = THCNumerics<AccT>::add(sum, ScalarConvert<T, AccT>::to(val));
     }
@@ -218,7 +221,7 @@ sampleMultinomialOnce(int64_t* dest,
       AccT val =
         cat < categories ?
           THCNumerics<AccT>::div(
-            ScalarConvert<T, AccT>::to(dist[curDist * categories + cat]),
+            ScalarConvert<T, AccT>::to(dist[curDist * stride_dist + cat * stride_categories]),
             sum) :
           accZero;
 
@@ -272,7 +275,7 @@ sampleMultinomialOnce(int64_t* dest,
     // where the distribution is non-zero. This is obviously terribly inefficient, but due to the
     // rarity in which this occurs, this should not be an issue.
     for (int cat = categories - 1; cat >= 0; --cat) {
-      if (THCNumerics<T>::gt(dist[curDist * categories + cat], zero)) {
+      if (THCNumerics<T>::gt(dist[curDist * stride_dist + cat * stride_categories], zero)) {
         dest[curDist] = cat + TH_INDEX_BASE;
         break;
       }
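The kernel's new indexing, dist[curDist * stride_dist + cat * stride_categories], is plain strided addressing (row index times row stride plus column index times column stride), which is what lets it read non-contiguous inputs; the old curDist * categories + cat is only valid for contiguous rows. The same arithmetic can be checked against PyTorch's stride() (illustrative sketch, variable names are not from the kernel):

```python
import torch

base = torch.rand(8, 4)
d = base.t()                                   # (4, 8) non-contiguous view
stride_dist, stride_categories = d.stride()    # (1, 8) for this transpose
flat = base.view(-1)                           # underlying buffer in memory order

cur_dist, cat = 2, 5
offset = cur_dist * stride_dist + cat * stride_categories
assert flat[offset] == d[cur_dist, cat]        # strided formula finds the element
```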
3 changes: 3 additions & 0 deletions aten/src/THC/generic/THCTensor.h
@@ -73,6 +73,9 @@ THC_API THCTensor *THCTensor_(newExpand)(THCState *state, THCTensor *tensor, THL
 THC_API void THCTensor_(expand)(THCState *state, THCTensor *r, THCTensor *tensor, THLongStorage *sizes);
 THC_API void THCTensor_(expandNd)(THCState *state, THCTensor **rets, THCTensor **ops, int count);
 
+// resize* methods simply resize the storage, so they may not retain the current data at current indices.
+// This is especially likely when the tensor is not contiguous. In general, if you still need the current
+// values, do not use resize* unless you are doing size and stride tricks.
 THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride);
 THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src);
 THC_API void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, int64_t size0_);
40 changes: 20 additions & 20 deletions aten/src/THC/generic/THCTensorRandom.cu
@@ -162,13 +162,14 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
                "replacement");
   }
 
-  // It is possible that prob_dist is non-contiguous
-  THCTensor* probDistContig =
-    THCTensor_(newContiguous)(state, prob_dist);
+  int free_prob_dist = 0;
 
   // Restructure data for 2d
   if (inputSize == 1) {
-    THCTensor_(resize2d)(state, probDistContig, 1, numCategories);
+    THCTensor *temp = THCTensor_(new)(state);
+    THCTensor_(unsqueeze1d)(state, temp, prob_dist, 0);
+    prob_dist = temp;
+    free_prob_dist = 1;
   }
 
   THCudaLongTensor_resize2d(state, self, numDist, n_sample);
@@ -181,7 +182,7 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
   int maxShared = props->sharedMemPerBlock;
   int requiredShared = (numCategories < maxThreads ? numCategories : maxThreads)
                        * (sizeof(real) * sizeof(accreal));
 
   if (n_sample == 1 && maxShared >= requiredShared) {
     // Optimized allocation-free implementation
     // To exploit greater parallelism for the sampling, generate the
@@ -201,19 +202,22 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
         numDist,
         numCategories,
         THCTensor_(data)(state, sampled),
-        THCTensor_(data)(state, probDistContig));
+        THCTensor_(data)(state, prob_dist),
+        THCTensor_(stride)(state, prob_dist, 0),
+        THCTensor_(stride)(state, prob_dist, 1)
+        );
     THCTensor_(free)(state, sampled);
   } else {
     // Generic, slow implementation with memory allocations
 
     // For sampling without replacement, we modify the distribution
     // for subsequent samples in this space
     THCTensor* origDist = THCTensor_(new)(state);
-    THCTensor_(resizeAs)(state, origDist, probDistContig);
-    THCTensor_(copy)(state, origDist, probDistContig);
+    THCTensor_(resizeAs)(state, origDist, prob_dist);
+    THCTensor_(copy)(state, origDist, prob_dist);
 
     THCTensor* normDist = THCTensor_(new)(state);
-    THCTensor_(resizeAs)(state, normDist, probDistContig);
+    THCTensor_(resizeAs)(state, normDist, prob_dist);
 
     THCTensor* prefixSum = THCTensor_(new)(state);
 
@@ -289,14 +293,10 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
   // Revert data restructuring based on input sizes
   if (inputSize == 1) {
     THCudaLongTensor_resize1d(state, self, n_sample);
-
-    // Unfortunately, if prob_dist is contiguous already,
-    // newContiguous is not a private copy, so we have to restructure
-    // this too, so as to not affect prob_dist
-    THCTensor_(resize1d)(state, probDistContig, numCategories);
   }
 
-  THCTensor_(free)(state, probDistContig);
+  if (free_prob_dist) {
+    THCTensor_(free)(state, prob_dist);
+  }
 }
 
 THC_API void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCudaLongTensor *_J, THCTensor *_q){
@@ -308,10 +308,10 @@ THC_API void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_prob
   THCudaLongTensor *larger = THCudaLongTensor_newWithSize1d(state, inputsize);
   THCudaLongTensor *smaller_short = THCudaLongTensor_newWithSize1d(state, inputsize);
   THCudaLongTensor *larger_short = THCudaLongTensor_newWithSize1d(state, inputsize);
 
   THCudaLongTensor_resize1d(state, _J, inputsize);
   THCTensor_(resize1d)(state, _q, inputsize);
 
   real one = ScalarConvert<int64_t, real>::to(1);
   int inputBlockDim = THCCeilDiv((int)inputsize + BLOCK_SIZE - 1, BLOCK_SIZE);
   aliasMultinomialFilter
@@ -325,7 +325,7 @@ THC_API void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_prob
       THCudaLongTensor_data(state, larger_short),
       one, inputsize
   );
 
   THCudaLongTensor_nonzero(state, smaller_short, smaller);
   THCudaLongTensor_nonzero(state, larger_short, larger);
   int h_large_c = THCudaLongTensor_nElement(state, larger_short);
@@ -347,7 +347,7 @@ THC_API void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_prob
       THCudaLongTensor_data(state, _J),
       inputsize, q_max
   );
 
   THCudaLongTensor_free(state, smaller);
   THCudaLongTensor_free(state, larger);
   THCudaLongTensor_free(state, smaller_short);
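Net effect of the .cu changes: the n_sample == 1 fast path now samples straight from a possibly non-contiguous prob_dist (passing its strides to the kernel) instead of a contiguous copy, and the temporary unsqueezed view is freed only when it was actually created. A quick smoke test in the spirit of the new test_multinomial (assumes a CUDA device is available):

```python
import torch

probs = torch.rand(5, 10).cuda().t()    # (10, 5), non-contiguous
assert not probs.is_contiguous()
idx = torch.multinomial(probs, 3, replacement=True)
assert idx.shape == (10, 3)             # one row of samples per distribution
assert int(idx.max()) < 5               # indices drawn from the category axis
```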
3 changes: 3 additions & 0 deletions test/test_cuda.py
@@ -1190,6 +1190,9 @@ def test_view(self):
     def test_stft(self):
         TestTorch._test_stft(self, lambda t: t.cuda())
 
+    def test_multinomial(self):
+        TestTorch._test_multinomial(self, torch.cuda.FloatTensor)
+
     def test_broadcast(self):
         TestTorch._test_broadcast(self, lambda t: t.cuda())
 
21 changes: 21 additions & 0 deletions test/test_distributions.py
@@ -1621,6 +1621,16 @@ def test_multinomial_shape(self):
         self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1, 2))).size(), torch.Size((3, 3)))
 
     def test_categorical_shape(self):
+        # unbatched
+        dist = Categorical(variable([0.6, 0.3, 0.1]))
+        self.assertEqual(dist._batch_shape, torch.Size(()))
+        self.assertEqual(dist._event_shape, torch.Size(()))
+        self.assertEqual(dist.sample().size(), torch.Size(SCALAR_SHAPE))
+        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2,)))
+        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+        self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1))).size(), torch.Size((3, 1)))
+        # batched
         dist = Categorical(variable([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
         self.assertEqual(dist._batch_shape, torch.Size((3,)))
         self.assertEqual(dist._event_shape, torch.Size(()))
@@ -1631,6 +1641,17 @@ def test_categorical_shape(self):
         self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1))).size(), torch.Size((3, 3)))
 
     def test_one_hot_categorical_shape(self):
+        # unbatched
+        dist = OneHotCategorical(variable([0.6, 0.3, 0.1]))
+        self.assertEqual(dist._batch_shape, torch.Size(()))
+        self.assertEqual(dist._event_shape, torch.Size((3,)))
+        self.assertEqual(dist.sample().size(), torch.Size((3,)))
+        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
+        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
+        self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2,)))
+        self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
+        self.assertEqual(dist.log_prob(Variable(torch.ones(3, 3))).size(), torch.Size((3,)))
+        # batched
         dist = OneHotCategorical(variable([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
         self.assertEqual(dist._batch_shape, torch.Size((3,)))
         self.assertEqual(dist._event_shape, torch.Size((2,)))