Skip to content

Commit 8734b17

Browse files
Ailing Zhang and facebook-github-bot
authored and committed
Multinomial raise error (#12490)
Summary: Fixes #12260 #2896 ``` torch.multinomial(torch.FloatTensor([0, 1, 0, 0]), 3, replacement=False) ``` The old behavior was to return `0` once we ran out of positive categories. Now we raise an error, based on the discussion in the issue thread. - Add test cases for the CPU and CUDA paths; in the CUDA case `n_sample=1` is a simple special case, so we test against `n_sample=2` instead. Pull Request resolved: #12490 Differential Revision: D10278794 Pulled By: ailzhang fbshipit-source-id: d04de7a60f60d0c0d648b975db3f3961fcf42db1
1 parent b89a3b5 commit 8734b17

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

aten/src/TH/generic/THTensorRandom.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
285285
/* Get normalized cumulative distribution from prob distribution */
286286
double sum = 0;
287287
double val;
288+
int n_zeros = 0;
288289
for (j=0; j<n_categories; j++)
289290
{
290291
val = THStorage_(get)( \
@@ -300,6 +301,9 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
300301
2,
301302
"invalid multinomial distribution (encountering probability entry = infinity or NaN)");
302303
sum += val;
304+
if (val == 0) {
305+
n_zeros += 1;
306+
}
303307
THDoubleStorage_set(
304308
THTensor_getStoragePtr(cum_dist), \
305309
cum_dist->storage_offset()+j*cum_dist->stride(0), \
@@ -310,6 +314,10 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
310314
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
311315
2,
312316
"invalid multinomial distribution (sum of probabilities <= 0)");
317+
THArgCheckWithCleanup((with_replacement || (n_categories - n_zeros >= n_sample)),
318+
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
319+
2,
320+
"invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)");
313321
/* normalize cumulative probability distribution so that last val is 1
314322
i.e. doesn't assume original prob_dist row sums to one */
315323
if ( (sum > 0) || ( ( sum < 1.00001) && (sum > 0.99999) ) )

aten/src/THC/THCTensorRandom.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#include <curand_kernel.h>
99

10-
#define MAX_NUM_BLOCKS 200
10+
#define MAX_NUM_BLOCKS 200
1111
#define BLOCK_SIZE 256
1212
/* Separate kernel because curand_log_normal gets extra parameters. */
1313

@@ -126,6 +126,8 @@ __device__ int binarySearchForMultinomial(T* dist,
126126
T val) {
127127
int start = 0;
128128
int end = size;
129+
// dist[size - 1] = 0 => all zero prob dist
130+
assert(THCNumerics<T>::gt(dist[size - 1], 0));
129131

130132
while (end - start > 0) {
131133
int mid = start + (end - start) / 2;

test/test_cuda.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,17 +1594,6 @@ def test_multinomial(self):
15941594
r = torch.multinomial(p, 1)
15951595
self.assertNotEqual(r.min().item(), 0)
15961596

1597-
# multinomial without repeat but with less nonzero
1598-
# elements than draws
1599-
# the intention currently is to return 0 for those
1600-
# and match CPU behaviour, see issue #9062
1601-
p = torch.zeros(1, 5, device="cuda")
1602-
p[:, 1] = 1
1603-
r = torch.multinomial(p, 2, replacement=False)
1604-
expected = torch.zeros(1, 2, device="cuda", dtype=torch.long)
1605-
expected[:, 0] = 1
1606-
self.assertEqual(r, expected)
1607-
16081597
@staticmethod
16091598
def mute():
16101599
os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
@@ -1621,7 +1610,7 @@ def _spawn_method(self, method, arg):
16211610
def _test_multinomial_invalid_probs_cuda(probs):
16221611
try:
16231612
with torch.random.fork_rng(devices=[0]):
1624-
torch.multinomial(probs.to('cuda'), 1)
1613+
torch.multinomial(probs.to('cuda'), 2)
16251614
torch.cuda.synchronize()
16261615
return False # Should not be reached
16271616
except RuntimeError as e:
@@ -1635,10 +1624,11 @@ def _test_multinomial_invalid_probs_cuda(probs):
16351624
but we need it for creating another process with CUDA")
16361625
def test_multinomial_invalid_probs_cuda(self):
16371626
test_method = TestCuda._test_multinomial_invalid_probs_cuda
1638-
self._spawn_method(test_method, torch.Tensor([0, -1]))
1639-
self._spawn_method(test_method, torch.Tensor([0, inf]))
1640-
self._spawn_method(test_method, torch.Tensor([0, -inf]))
1641-
self._spawn_method(test_method, torch.Tensor([0, nan]))
1627+
self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
1628+
self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
1629+
self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
1630+
self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
1631+
self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
16421632

16431633
@skipIfRocm
16441634
def test_broadcast(self):

test/test_torch.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,7 +2851,8 @@ def _spawn_method(self, method, arg):
28512851
@staticmethod
28522852
def _test_multinomial_invalid_probs(probs):
28532853
try:
2854-
torch.multinomial(probs.to('cpu'), 1)
2854+
# n_sample = 1 is a special case, test n_sample=2 which is more general
2855+
torch.multinomial(probs.to('cpu'), 2)
28552856
return False # Should not be reached
28562857
except RuntimeError as e:
28572858
return 'invalid multinomial distribution' in str(e)
@@ -2864,10 +2865,11 @@ def _test_multinomial_invalid_probs(probs):
28642865
but we need it for for testing failure case for CPU RNG on Windows")
28652866
def test_multinomial_invalid_probs(self):
28662867
test_method = TestTorch._test_multinomial_invalid_probs
2867-
self._spawn_method(test_method, torch.Tensor([0, -1]))
2868-
self._spawn_method(test_method, torch.Tensor([0, inf]))
2869-
self._spawn_method(test_method, torch.Tensor([0, -inf]))
2870-
self._spawn_method(test_method, torch.Tensor([0, nan]))
2868+
self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
2869+
self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
2870+
self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
2871+
self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
2872+
self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
28712873

28722874
@suppress_warnings
28732875
def test_range(self):

0 commit comments

Comments
 (0)