Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions aten/src/TH/generic/THTensorRandom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
/* Get normalized cumulative distribution from prob distribution */
double sum = 0;
double val;
int n_zeros = 0;
for (j=0; j<n_categories; j++)
{
val = THStorage_(get)( \
Expand All @@ -300,6 +301,9 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
2,
"invalid multinomial distribution (encountering probability entry = infinity or NaN)");
sum += val;
if (val == 0) {
n_zeros += 1;
}
THDoubleStorage_set(
THTensor_getStoragePtr(cum_dist), \
cum_dist->storage_offset()+j*cum_dist->stride(0), \
Expand All @@ -310,6 +314,10 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
2,
"invalid multinomial distribution (sum of probabilities <= 0)");
THArgCheckWithCleanup((with_replacement || (n_categories - n_zeros >= n_sample)),
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

2,
"invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)");
/* normalize cumulative probability distribution so that last val is 1
i.e. doesn't assume original prob_dist row sums to one */
if ( (sum > 0) || ( ( sum < 1.00001) && (sum > 0.99999) ) )
Expand Down
4 changes: 3 additions & 1 deletion aten/src/THC/THCTensorRandom.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include <curand_kernel.h>

#define MAX_NUM_BLOCKS 200
#define MAX_NUM_BLOCKS 200
#define BLOCK_SIZE 256
/* Separate kernel because curand_log_normal gets extra parameters. */

Expand Down Expand Up @@ -126,6 +126,8 @@ __device__ int binarySearchForMultinomial(T* dist,
T val) {
int start = 0;
int end = size;
// dist[size - 1] = 0 => all zero prob dist
assert(THCNumerics<T>::gt(dist[size - 1], 0));

while (end - start > 0) {
int mid = start + (end - start) / 2;
Expand Down
22 changes: 6 additions & 16 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,17 +1594,6 @@ def test_multinomial(self):
r = torch.multinomial(p, 1)
self.assertNotEqual(r.min().item(), 0)

# multinomial without repeat but with less nonzero
# elements than draws
# the intention currently is to return 0 for those
# and match CPU behaviour, see issue #9062
p = torch.zeros(1, 5, device="cuda")
p[:, 1] = 1
r = torch.multinomial(p, 2, replacement=False)
expected = torch.zeros(1, 2, device="cuda", dtype=torch.long)
expected[:, 0] = 1
self.assertEqual(r, expected)

@staticmethod
def mute():
os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
Expand All @@ -1621,7 +1610,7 @@ def _spawn_method(self, method, arg):
def _test_multinomial_invalid_probs_cuda(probs):
try:
with torch.random.fork_rng(devices=[0]):
torch.multinomial(probs.to('cuda'), 1)
torch.multinomial(probs.to('cuda'), 2)
torch.cuda.synchronize()
return False # Should not be reached
except RuntimeError as e:
Expand All @@ -1635,10 +1624,11 @@ def _test_multinomial_invalid_probs_cuda(probs):
but we need it for creating another process with CUDA")
def test_multinomial_invalid_probs_cuda(self):
test_method = TestCuda._test_multinomial_invalid_probs_cuda
self._spawn_method(test_method, torch.Tensor([0, -1]))
self._spawn_method(test_method, torch.Tensor([0, inf]))
self._spawn_method(test_method, torch.Tensor([0, -inf]))
self._spawn_method(test_method, torch.Tensor([0, nan]))
self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
self._spawn_method(test_method, torch.Tensor([0, 1, 0]))

@skipIfRocm
def test_broadcast(self):
Expand Down
12 changes: 7 additions & 5 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2851,7 +2851,8 @@ def _spawn_method(self, method, arg):
@staticmethod
def _test_multinomial_invalid_probs(probs):
try:
torch.multinomial(probs.to('cpu'), 1)
# n_sample = 1 is a special case, test n_sample=2 which is more general
torch.multinomial(probs.to('cpu'), 2)
return False # Should not be reached
except RuntimeError as e:
return 'invalid multinomial distribution' in str(e)
Expand All @@ -2864,10 +2865,11 @@ def _test_multinomial_invalid_probs(probs):
but we need it for for testing failure case for CPU RNG on Windows")
def test_multinomial_invalid_probs(self):
test_method = TestTorch._test_multinomial_invalid_probs
self._spawn_method(test_method, torch.Tensor([0, -1]))
self._spawn_method(test_method, torch.Tensor([0, inf]))
self._spawn_method(test_method, torch.Tensor([0, -inf]))
self._spawn_method(test_method, torch.Tensor([0, nan]))
self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
self._spawn_method(test_method, torch.Tensor([0, 1, 0]))

@suppress_warnings
def test_range(self):
Expand Down