|
@@ -31,8 +31,8 @@
 from common import TestCase, run_tests, set_rng_seed
 from torch.autograd import Variable, grad, gradcheck
 from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
-                                 Dirichlet, Exponential, Gamma, Gumbel,
-                                 Laplace, Normal, OneHotCategorical, Pareto,
+                                 Dirichlet, Exponential, Gamma, Gumbel, Laplace,
+                                 Normal, OneHotCategorical, Multinomial, Pareto,
                                  StudentT, Uniform, kl_divergence)
 from torch.distributions.dirichlet import _Dirichlet_backward
 from torch.distributions.constraints import Constraint, is_dependent
|
@@ -69,6 +69,10 @@
         {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
         {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
     ]),
+    Example(Multinomial, [
+        {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
+        {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+    ]),
     Example(Cauchy, [
        {'loc': 0.0, 'scale': 1.0},
        {'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
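
Note: `Example` in this test file is a small namedtuple pairing a distribution class with a list of constructor-kwarg dicts, so the entry above simply registers two `Multinomial` parameterizations with every generic test (shapes, parameter constraints, gradients). A minimal sketch of how the registry is consumed, assuming the `Example = namedtuple('Example', ['Dist', 'params'])` definition the file uses:

    # each params dict is splatted into the constructor under test
    for Dist, params in EXAMPLES:
        for param in params:
            dist = Dist(**param)  # e.g. Multinomial(total_count=10, probs=...)
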
@@ -294,6 +298,53 @@ def test_bernoulli_3d(self): |
                          (2, 5, 2, 3, 5))
         self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))
 
+    def test_multinomial_1d(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
+        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(1).size(), (1, 3))
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+        self.assertRaises(NotImplementedError, Multinomial(total_count, p).rsample)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_multinomial_1d_log_prob(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        dist = Multinomial(total_count, probs=p)
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+        dist = Multinomial(total_count, logits=p.log())
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+    def test_multinomial_2d(self):
+        total_count = 10
+        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
+        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
+        p = Variable(torch.Tensor(probabilities), requires_grad=True)
+        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(6).size(), (6, 2, 3))
+        set_rng_seed(0)
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        p.grad.zero_()
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+
+        # sample check for an extreme (deterministic) value of probs
+        self.assertEqual(Multinomial(total_count, s).sample().data,
+                         torch.Tensor([[total_count, 0], [0, total_count]]))
+
+        # entropy is not yet implemented for Multinomial, so it should raise
+        self.assertRaises(NotImplementedError, Multinomial(total_count, p).entropy)
+
     def test_categorical_1d(self):
         p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
         # TODO: this should return a 0-dim tensor once we have Scalar support
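
Note on the reference values in `test_multinomial_1d_log_prob`: `scipy.stats.multinomial.logpmf` evaluates the standard multinomial log-pmf, log P(x) = log(n!) - sum_i log(x_i!) + sum_i x_i * log(p_i), and the test feeds it `dist.probs`, i.e. the normalized probabilities. A hand computation of the same quantity (illustration only, not part of the patch):

    import math

    def multinomial_log_pmf(counts, total_count, probs):
        # multinomial coefficient: log(n!) - sum_i log(x_i!)
        log_coeff = math.lgamma(total_count + 1) - sum(math.lgamma(c + 1) for c in counts)
        # skip zero counts so that 0 * log(0) contributes 0 rather than NaN
        return log_coeff + sum(c * math.log(q) for c, q in zip(counts, probs) if c > 0)

    # the unnormalized probs [0.1, 0.2, 0.3] sum to 0.6, so compare against
    # the normalized values, as the test does via dist.probs:
    print(multinomial_log_pmf([2, 3, 5], 10, [1. / 6, 2. / 6, 3. / 6]))
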
@@ -1096,13 +1147,16 @@ def test_entropy_shape(self): |
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                actual_shape = dist.entropy().size()
-                expected_shape = dist._batch_shape
-                if not expected_shape:
-                    expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
-                message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
-                    Dist.__name__, i, len(params), expected_shape, actual_shape)
-                self.assertEqual(actual_shape, expected_shape, message=message)
+                try:
+                    actual_shape = dist.entropy().size()
+                    expected_shape = dist._batch_shape
+                    if not expected_shape:
+                        expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
+                    message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
+                        Dist.__name__, i, len(params), expected_shape, actual_shape)
+                    self.assertEqual(actual_shape, expected_shape, message=message)
+                except NotImplementedError:
+                    continue
 
     def test_bernoulli_shape_scalar_params(self):
         bernoulli = Bernoulli(0.3)
@@ -1145,6 +1199,16 @@ def test_beta_shape_tensor_params(self): |
         self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
         self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
 
+    def test_multinomial_shape(self):
+        dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+        self.assertEqual(dist._batch_shape, torch.Size((3,)))
+        self.assertEqual(dist._event_shape, torch.Size((2,)))
+        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
+        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
+        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
+
     def test_categorical_shape(self):
         dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
         self.assertEqual(dist._batch_shape, torch.Size((3,)))
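
Note on `test_multinomial_shape`: the checks hinge on the batch/event split. With a `(3, 2)` `probs` tensor, the leading dimension is the batch (three independent parameterizations) and the trailing dimension is the event (per-category counts), so `log_prob` reduces over the last dimension only. A sketch of the expected behaviour (same constructor signature as in the patch; illustration only):

    d = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
    counts = d.sample()        # size (3, 2); each batch row of counts sums to 10
    logp = d.log_prob(counts)  # size (3,): one log-probability per batch row
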
@@ -1375,11 +1439,14 @@ def test_params_contains(self): |
                 for name, value in param.items():
                     if not (torch.is_tensor(value) or isinstance(value, Variable)):
                         value = torch.Tensor([value])
-                    if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
                         # These distributions accept positive probs, but elsewhere we
                         # use a stricter constraint to the simplex.
                         value = value / value.sum(-1, True)
-                    constraint = dist.params[name]
+                    try:
+                        constraint = dist.params[name]
+                    except KeyError:
+                        continue  # ignore optional parameters (e.g. total_count)
                     if is_dependent(constraint):
                         continue
                     message = '{} example {}/{} parameter {} = {}'.format(
@@ -1499,6 +1566,23 @@ def test_categorical_log_prob_with_logits(self): |
         log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
         self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
 
+    def test_multinomial_log_prob(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([0, 1]), requires_grad=True)
+            s = Variable(tensor_type([0, 10]))
+            multinomial = Multinomial(10, p)
+            log_pdf = multinomial.log_prob(s)
+            self.assertEqual(log_pdf.data[0], 0)
+
+    def test_multinomial_log_prob_with_logits(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+            multinomial = Multinomial(10, logits=p)
+            log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
+            self.assertEqual(log_pdf_prob_1.data[0], 0)
+            log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
+            self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
+
 
 if __name__ == '__main__':
     run_tests()
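
Note on the two degenerate-parameter tests above: with `probs = [0, 1]` every trial must land in the second category, so the count vector `[0, 10]` is certain and its log-probability is exactly 0; with a `-inf` logit the first category is impossible, so any sample placing counts there gets `-inf`. The same invariants as a standalone check (same Variable-era API as the patch; illustration only):

    p = Variable(torch.Tensor([0.0, 1.0]))
    m = Multinomial(10, p)
    assert m.log_prob(Variable(torch.Tensor([0.0, 10.0]))).data[0] == 0.0
    m = Multinomial(10, logits=Variable(torch.Tensor([-float('inf'), 0.0])))
    assert m.log_prob(Variable(torch.Tensor([10.0, 0.0]))).data[0] == -float('inf')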