Skip to content

Commit 6e1ed7d

Browse files
committed
update test
1 parent c2bc6d7 commit 6e1ed7d

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

test/test_torch.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17796,13 +17796,19 @@ def test(probs, replacement):
1779617796
test(y, True)
1779717797
test(z, True)
1779817798

17799-
def test_multinomial_empty(self, device):
17800-
probs = torch.ones(0, 128, device=device)
17801-
num_samples = 64
17799+
def _test_multinomial_empty(self, device, replacement, num_samples):
17800+
probs = torch.ones(0, 3, device=device)
1780217801
expected = torch.empty(0, num_samples, dtype=torch.int64)
17803-
for replacement in (True, False):
17804-
out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
17805-
self.assertEqual(out, expected)
17802+
out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
17803+
self.assertEqual(out, expected)
17804+
17805+
def test_multinomial_empty_w_replacement(self, device):
17806+
self._test_multinomial_empty(device, True, 1)
17807+
self._test_multinomial_empty(device, True, 2)
17808+
17809+
def test_multinomial_empty_wo_replacement(self, device):
17810+
self._test_multinomial_empty(device, False, 1)
17811+
self._test_multinomial_empty(device, False, 2)
1780617812

1780717813
def _generate_input(self, shape, dtype, device, with_extremal):
1780817814
if shape == ():

0 commit comments

Comments
 (0)