
Commit 9b6441e

alicanb authored and apaszke committed
Implement Multinomial distribution (#4624)
1 parent 8eded5a commit 9b6441e

File tree

- test/test_distributions.py
- torch/distributions/__init__.py
- torch/distributions/categorical.py
- torch/distributions/constraints.py
- torch/distributions/multinomial.py

5 files changed: +200 −13 lines changed

test/test_distributions.py

Lines changed: 95 additions & 11 deletions
@@ -31,8 +31,8 @@
 from common import TestCase, run_tests, set_rng_seed
 from torch.autograd import Variable, grad, gradcheck
 from torch.distributions import (Bernoulli, Beta, Categorical, Cauchy, Chi2,
-                                 Dirichlet, Exponential, Gamma, Gumbel,
-                                 Laplace, Normal, OneHotCategorical, Pareto,
+                                 Dirichlet, Exponential, Gamma, Gumbel, Laplace,
+                                 Normal, OneHotCategorical, Multinomial, Pareto,
                                  StudentT, Uniform, kl_divergence)
 from torch.distributions.dirichlet import _Dirichlet_backward
 from torch.distributions.constraints import Constraint, is_dependent
@@ -69,6 +69,10 @@
         {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
         {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
     ]),
+    Example(Multinomial, [
+        {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
+        {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+    ]),
     Example(Cauchy, [
         {'loc': 0.0, 'scale': 1.0},
         {'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
@@ -294,6 +298,53 @@ def test_bernoulli_3d(self):
                          (2, 5, 2, 3, 5))
         self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))

+    def test_multinomial_1d(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
+        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(1).size(), (1, 3))
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_multinomial_1d_log_prob(self):
+        total_count = 10
+        p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+        dist = Multinomial(total_count, probs=p)
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+        dist = Multinomial(total_count, logits=p.log())
+        x = dist.sample()
+        log_prob = dist.log_prob(x)
+        expected = torch.Tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
+        self.assertEqual(log_prob.data, expected)
+
+    def test_multinomial_2d(self):
+        total_count = 10
+        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
+        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
+        p = Variable(torch.Tensor(probabilities), requires_grad=True)
+        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
+        self.assertEqual(Multinomial(total_count, p).sample_n(6).size(), (6, 2, 3))
+        set_rng_seed(0)
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
+        p.grad.zero_()
+        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
+
+        # sample check for extreme value of probs
+        self.assertEqual(Multinomial(total_count, s).sample().data,
+                         torch.Tensor([[total_count, 0], [0, total_count]]))
+
+        # check entropy computation
+        self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
+
     def test_categorical_1d(self):
         p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
         # TODO: this should return a 0-dim tensor once we have Scalar support
@@ -1096,13 +1147,16 @@ def test_entropy_shape(self):
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
-                actual_shape = dist.entropy().size()
-                expected_shape = dist._batch_shape
-                if not expected_shape:
-                    expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
-                message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
-                    Dist.__name__, i, len(params), expected_shape, actual_shape)
-                self.assertEqual(actual_shape, expected_shape, message=message)
+                try:
+                    actual_shape = dist.entropy().size()
+                    expected_shape = dist._batch_shape
+                    if not expected_shape:
+                        expected_shape = torch.Size((1,))  # TODO Remove this once scalars are supported.
+                    message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
+                        Dist.__name__, i, len(params), expected_shape, actual_shape)
+                    self.assertEqual(actual_shape, expected_shape, message=message)
+                except NotImplementedError:
+                    continue

     def test_bernoulli_shape_scalar_params(self):
         bernoulli = Bernoulli(0.3)
@@ -1145,6 +1199,16 @@ def test_beta_shape_tensor_params(self):
         self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
         self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))

+    def test_multinomial_shape(self):
+        dist = Multinomial(10, torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+        self.assertEqual(dist._batch_shape, torch.Size((3,)))
+        self.assertEqual(dist._event_shape, torch.Size((2,)))
+        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
+        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
+        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
+
     def test_categorical_shape(self):
         dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
         self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -1375,11 +1439,14 @@ def test_params_contains(self):
                 for name, value in param.items():
                     if not (torch.is_tensor(value) or isinstance(value, Variable)):
                         value = torch.Tensor([value])
-                    if Dist in (Categorical, OneHotCategorical) and name == 'probs':
+                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
                         # These distributions accept positive probs, but elsewhere we
                         # use a stricter constraint to the simplex.
                         value = value / value.sum(-1, True)
-                    constraint = dist.params[name]
+                    try:
+                        constraint = dist.params[name]
+                    except KeyError:
+                        continue  # ignore optional parameters
                     if is_dependent(constraint):
                         continue
                     message = '{} example {}/{} parameter {} = {}'.format(
@@ -1499,6 +1566,23 @@ def test_categorical_log_prob_with_logits(self):
         log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
         self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)

+    def test_multinomial_log_prob(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([0, 1]), requires_grad=True)
+            s = Variable(tensor_type([0, 10]))
+            multinomial = Multinomial(10, p)
+            log_pdf = multinomial.log_prob(s)
+            self.assertEqual(log_pdf.data[0], 0)
+
+    def test_multinomial_log_prob_with_logits(self):
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+            multinomial = Multinomial(10, logits=p)
+            log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
+            self.assertEqual(log_pdf_prob_1.data[0], 0)
+            log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
+            self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)
+

 if __name__ == '__main__':
     run_tests()
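The new shape tests follow the package's `sample_shape + batch_shape + event_shape` convention. A minimal sketch of what they assert, under the API introduced here (toy values, not from the diff; modern PyTorch no longer needs `Variable`)::

    import torch
    from torch.distributions import Multinomial

    # batch_shape (2,), event_shape (3,): two multinomials over three categories.
    p = torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]])
    d = Multinomial(10, p)

    print(d.sample().size())        # torch.Size([2, 3])        -> batch + event
    print(d.sample((3, 4)).size())  # torch.Size([3, 4, 2, 3])  -> sample + batch + event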

torch/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@
 from .gumbel import Gumbel
 from .kl import kl_divergence, register_kl
 from .laplace import Laplace
+from .multinomial import Multinomial
 from .normal import Normal
 from .one_hot_categorical import OneHotCategorical
 from .pareto import Pareto

@@ -60,6 +61,7 @@
     'Gamma',
     'Gumbel',
     'Laplace',
+    'Multinomial',
     'Normal',
     'OneHotCategorical',
     'Pareto',
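With the exports updated, the new class is importable from the package root; a one-line sanity check (not part of the diff)::

    from torch.distributions import Multinomial  # resolves via the new __all__ entry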

torch/distributions/categorical.py

Lines changed: 5 additions & 2 deletions
@@ -7,10 +7,12 @@

 class Categorical(Distribution):
     r"""
-    Creates a categorical distribution parameterized by `probs`.
+    Creates a categorical distribution parameterized by either `probs` or
+    `logits` (but not both).

     .. note::
-        It is equivalent to the distribution that ``multinomial()`` samples from.
+        It is equivalent to the distribution that :func:`torch.multinomial`
+        samples from.

     Samples are integers from `0 ... K-1` where `K` is probs.size(-1).

@@ -30,6 +32,7 @@ class Categorical(Distribution):

     Args:
         probs (Tensor or Variable): event probabilities
+        logits (Tensor or Variable): event log probabilities
     """
     params = {'probs': constraints.simplex}
     has_enumerate_support = True
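A short sketch of the two parameterizations the revised docstring names; toy values, not from the diff (modern PyTorch no longer needs `Variable`)::

    import torch
    from torch.distributions import Categorical

    probs = torch.Tensor([0.25, 0.25, 0.5])
    d_probs = Categorical(probs=probs)          # probability-space parameterization
    d_logits = Categorical(logits=probs.log())  # equivalent log-space parameterization

    # Both normalize to the same point on the simplex and sample identically.
    print(d_probs.probs)   # 0.2500  0.2500  0.5000
    print(d_logits.probs)  # 0.2500  0.2500  0.5000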

torch/distributions/constraints.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,7 @@
     'integer_interval',
     'interval',
     'is_dependent',
+    'less_than',
     'lower_triangular',
     'nonnegative_integer',
     'positive',

@@ -112,6 +113,17 @@ def check(self, value):
         return self.lower_bound <= value


+class _LessThan(Constraint):
+    """
+    Constrain to a real half line `(-inf, upper_bound]`.
+    """
+    def __init__(self, upper_bound):
+        self.upper_bound = upper_bound
+
+    def check(self, value):
+        return value <= self.upper_bound
+
+
 class _Interval(Constraint):
     """
     Constrain to a real interval `[lower_bound, upper_bound]`.

@@ -150,6 +162,7 @@ def check(self, value):
 real = _Real()
 positive = _GreaterThan(0)
 greater_than = _GreaterThan
+less_than = _LessThan
 unit_interval = _Interval(0, 1)
 interval = _Interval
 simplex = _Simplex()
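A hedged usage sketch of the new `less_than` factory, mirroring the existing `greater_than`; toy values, not from the diff::

    import torch
    from torch.distributions import constraints

    bounded_above = constraints.less_than(1.0)
    mask = bounded_above.check(torch.Tensor([0.5, 1.0, 1.5]))
    print(mask)  # 1, 1, 0 -- elementwise result of the `value <= upper_bound` check above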

torch/distributions/multinomial.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+import torch
+from torch.distributions.distribution import Distribution
+from torch.autograd import Variable
+from torch.distributions import Categorical
+from numbers import Number
+from torch.distributions import constraints
+from torch.distributions.utils import log_sum_exp, broadcast_all
+
+
+class Multinomial(Distribution):
+    r"""
+    Creates a Multinomial distribution parameterized by `total_count` and
+    either `probs` or `logits` (but not both). The innermost dimension of
+    `probs` indexes over categories. All other dimensions index over batches.
+
+    Note that `total_count` need not be specified if only :meth:`log_prob` is
+    called (see example below).
+
+    - :meth:`sample` requires a single shared `total_count` for all
+      parameters and samples.
+    - :meth:`log_prob` allows different `total_count` for each parameter and
+      sample.
+
+    Example::
+
+        >>> m = Multinomial(100, torch.Tensor([ 1, 1, 1, 1]))
+        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
+         21
+         24
+         30
+         25
+        [torch.FloatTensor of size 4]
+
+        >>> Multinomial(probs=torch.Tensor([1, 1, 1, 1])).log_prob(x)
+        -4.1338
+        [torch.FloatTensor of size 1]
+
+    Args:
+        total_count (int): number of trials
+        probs (Tensor or Variable): event probabilities
+        logits (Tensor or Variable): event log probabilities
+    """
+    params = {'logits': constraints.real}  # Let logits be the canonical parameterization.
+
+    def __init__(self, total_count=1, probs=None, logits=None):
+        if not isinstance(total_count, Number):
+            raise NotImplementedError('inhomogeneous total_count is not supported')
+        self.total_count = total_count
+        self._categorical = Categorical(probs=probs, logits=logits)
+        batch_shape = probs.size()[:-1] if probs is not None else logits.size()[:-1]
+        event_shape = probs.size()[-1:] if probs is not None else logits.size()[-1:]
+        super(Multinomial, self).__init__(batch_shape, event_shape)
+
+    @constraints.dependent_property
+    def support(self):
+        return constraints.integer_interval(0, self.total_count)
+
+    @property
+    def logits(self):
+        return self._categorical.logits
+
+    @property
+    def probs(self):
+        return self._categorical.probs
+
+    def sample(self, sample_shape=torch.Size()):
+        sample_shape = torch.Size(sample_shape)
+        samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
+        # samples.shape is (total_count, sample_shape, batch_shape); move the
+        # total_count dimension last to get (sample_shape, batch_shape, total_count)
+        shifted_idx = list(range(samples.dim()))
+        shifted_idx.append(shifted_idx.pop(0))
+        samples = samples.permute(*shifted_idx)
+        counts = samples.new(self._extended_shape(sample_shape)).zero_()
+        counts.scatter_add_(-1, samples, torch.ones_like(samples))
+        return counts.type_as(self.probs)
+
+    def log_prob(self, value):
+        self._validate_log_prob_arg(value)
+        logits, value = broadcast_all(self.logits.clone(), value)
+        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
+        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
+        logits[(value == 0) & (logits == -float('inf'))] = 0
+        log_powers = (logits * value).sum(-1)
+        return log_factorial_n - log_factorial_xs + log_powers
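Two pieces of the new class are worth unpacking: `sample` reduces `total_count` categorical draws to per-category counts via `scatter_add_`, and `log_prob` evaluates the multinomial log-pmf `log n! - sum_i log x_i! + sum_i x_i * log p_i`, using `lgamma(k + 1) == log(k!)`. A minimal sketch of both, assuming this commit's API (modern PyTorch no longer needs `Variable`)::

    import torch
    from torch.distributions import Categorical, Multinomial

    probs = torch.Tensor([0.1, 0.2, 0.3, 0.4])
    total_count = 10

    # How sample() builds counts: draw total_count category indices, then
    # scatter-add ones into a zero vector indexed by category.
    idx = Categorical(probs).sample(torch.Size((total_count,)))  # e.g. tensor([3, 1, 3, ...])
    counts = torch.zeros(4)
    counts.scatter_add_(-1, idx, torch.ones_like(idx, dtype=counts.dtype))

    # How log_prob() scores a count vector: the multinomial log-pmf.
    x = counts
    expected = (torch.lgamma(x.sum() + 1)     # log n!
                - torch.lgamma(x + 1).sum()   # - sum_i log x_i!
                + (x * probs.log()).sum())    # + sum_i x_i * log p_i

    print(Multinomial(total_count, probs).log_prob(x))  # matches `expected` up to rounding
    print(expected)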
