@@ -17796,13 +17796,19 @@ def test(probs, replacement):
1779617796 test(y, True)
1779717797 test(z, True)
1779817798
17799- def test_multinomial_empty(self, device):
17800- probs = torch.ones(0, 128, device=device)
17801- num_samples = 64
17799+ def _test_multinomial_empty(self, device, replacement, num_samples):
17800+ probs = torch.ones(0, 3, device=device)
1780217801 expected = torch.empty(0, num_samples, dtype=torch.int64)
17803- for replacement in (True, False):
17804- out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
17805- self.assertEqual(out, expected)
17802+ out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
17803+ self.assertEqual(out, expected)
17804+
17805+ def test_multinomial_empty_w_replacement(self, device):
17806+ self._test_multinomial_empty(device, True, 1)
17807+ self._test_multinomial_empty(device, True, 2)
17808+
17809+ def test_multinomial_empty_wo_replacement(self, device):
17810+ self._test_multinomial_empty(device, False, 1)
17811+ self._test_multinomial_empty(device, False, 2)
1780617812
1780717813 def _generate_input(self, shape, dtype, device, with_extremal):
1780817814 if shape == ():
0 commit comments