Skip to content

Commit 0394c5a

Browse files
kshitij12345facebook-github-bot
authored andcommitted
[fix] torch.multinomial : fix for 0 size dim (#43775)
Summary: Fixes #43768 TO-DO: * [x] Add test Pull Request resolved: #43775 Reviewed By: ZolotukhinM Differential Revision: D23421979 Pulled By: ngimel fbshipit-source-id: 949fcdd30f18d17ae1c372fa6ca6a0b8d0d538ce
1 parent 3c8b1d7 commit 0394c5a

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

aten/src/ATen/native/Distributions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo
454454
if (self.dim() > 1) {
455455
int64_t n_dist = self.size(-2);
456456
result.resize_({n_dist, n_sample});
457+
if (n_dist == 0) { return result; };
457458
} else {
458459
result.resize_({n_sample});
459460
}

aten/src/ATen/native/cuda/MultinomialKernel.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
395395
}
396396
});
397397
398+
AT_CUDA_CHECK(cudaGetLastError());
399+
398400
if (inputSize == 1) {
399401
result.resize_({n_sample});
400402
}

test/test_torch.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17942,13 +17942,19 @@ def test(probs, replacement):
1794217942
test(y, True)
1794317943
test(z, True)
1794417944

17945-
def test_multinomial_empty(self, device):
17946-
probs = torch.ones(0, 3)
17947-
num_samples = 1
17945+
def _test_multinomial_empty(self, device, replacement, num_samples):
17946+
probs = torch.ones(0, 3, device=device)
1794817947
expected = torch.empty(0, num_samples, dtype=torch.int64)
17949-
for replacement in (True, False):
17950-
out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
17951-
self.assertEqual(out, expected)
17948+
out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
17949+
self.assertEqual(out, expected)
17950+
17951+
def test_multinomial_empty_w_replacement(self, device):
17952+
self._test_multinomial_empty(device, True, 1)
17953+
self._test_multinomial_empty(device, True, 2)
17954+
17955+
def test_multinomial_empty_wo_replacement(self, device):
17956+
self._test_multinomial_empty(device, False, 1)
17957+
self._test_multinomial_empty(device, False, 2)
1795217958

1795317959
def _generate_input(self, shape, dtype, device, with_extremal):
1795417960
if shape == ():

0 commit comments

Comments
 (0)