Skip to content

Commit 3964253

Browse files
neerajprad authored and apaszke committed
Allowing for vectorized counts in Binomial Distribution (#6720)
1 parent f98b778 commit 3964253

File tree

3 files changed

+88
-27
lines changed

3 files changed

+88
-27
lines changed

test/test_distributions.py

Lines changed: 52 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,12 @@ def is_all_nan(tensor):
113113
Example(Binomial, [
114114
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
115115
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
116+
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])},
117+
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
118+
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
119+
'total_count': torch.tensor([[10., 8.], [5., 3.]])},
120+
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
121+
'total_count': torch.tensor(0.)},
116122
]),
117123
Example(Multinomial, [
118124
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
@@ -795,6 +801,15 @@ def ref_log_prob(idx, x, log_prob):
795801
logits = probs_to_logits(probs, is_binary=True)
796802
self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)
797803

804+
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
805+
def test_binomial_log_prob_vectorized_count(self):
806+
probs = torch.tensor([0.2, 0.7, 0.9])
807+
for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])),
808+
(torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]:
809+
log_prob = Binomial(total_count, probs).log_prob(sample)
810+
expected = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy()).logpmf(sample)
811+
self.assertAlmostEqual(log_prob, expected, places=4)
812+
798813
def test_binomial_extreme_vals(self):
799814
total_count = 100
800815
bin0 = Binomial(total_count, 0)
@@ -805,6 +820,28 @@ def test_binomial_extreme_vals(self):
805820
self.assertEqual(bin1.sample(), total_count)
806821
self.assertAlmostEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, places=3)
807822
self.assertEqual(float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0, allow_inf=True)
823+
zero_counts = torch.zeros(torch.Size((2, 2)))
824+
bin2 = Binomial(zero_counts, 1)
825+
self.assertEqual(bin2.sample(), zero_counts)
826+
self.assertEqual(bin2.log_prob(zero_counts), zero_counts)
827+
828+
def test_binomial_vectorized_count(self):
829+
set_rng_seed(0)
830+
total_count = torch.tensor([[4, 7], [3, 8]])
831+
bin0 = Binomial(total_count, torch.tensor(1.))
832+
self.assertEqual(bin0.sample(), total_count)
833+
bin1 = Binomial(total_count, torch.tensor(0.5))
834+
samples = bin1.sample(torch.Size((100000,)))
835+
self.assertTrue((samples <= total_count.type_as(samples)).all())
836+
self.assertEqual(samples.mean(dim=0), bin1.mean, prec=0.02)
837+
self.assertEqual(samples.var(dim=0), bin1.variance, prec=0.02)
838+
839+
def test_binomial_enumerate_support(self):
840+
set_rng_seed(0)
841+
bin0 = Binomial(0, torch.tensor(1.))
842+
self.assertEqual(bin0.enumerate_support(), torch.tensor([0.]))
843+
bin1 = Binomial(torch.tensor(5), torch.tensor(0.5))
844+
self.assertEqual(bin1.enumerate_support(), torch.arange(6))
808845

809846
def test_multinomial_1d(self):
810847
total_count = 10
@@ -1793,9 +1830,8 @@ def test_independent_shape(self):
17931830
self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
17941831
if indep_dist.has_rsample:
17951832
self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
1796-
if indep_dist.has_enumerate_support:
1797-
self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
17981833
try:
1834+
self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
17991835
self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
18001836
except NotImplementedError:
18011837
pass
@@ -2301,6 +2337,15 @@ def test_binomial_shape(self):
23012337
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
23022338
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
23032339

2340+
def test_binomial_shape_vectorized_n(self):
2341+
dist = Binomial(torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1]))
2342+
self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
2343+
self.assertEqual(dist._event_shape, torch.Size(()))
2344+
self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
2345+
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
2346+
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2347+
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
2348+
23042349
def test_multinomial_shape(self):
23052350
dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
23062351
self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -2562,6 +2607,8 @@ def __init__(self, probs):
25622607
# e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
25632608
bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
25642609
binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
2610+
binomial_vectorized_count = (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
2611+
Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])))
25652612
beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
25662613
categorical = pairwise(Categorical, [[0.4, 0.3, 0.3],
25672614
[0.2, 0.7, 0.1],
@@ -2607,6 +2654,7 @@ def __init__(self, probs):
26072654
(beta, gamma),
26082655
(beta, normal),
26092656
(binomial30, binomial30),
2657+
(binomial_vectorized_count, binomial_vectorized_count),
26102658
(categorical, categorical),
26112659
(chi2, chi2),
26122660
(chi2, exponential),
@@ -2654,6 +2702,8 @@ def __init__(self, probs):
26542702
(Beta(1, 2), Uniform(0.25, 0.75)),
26552703
(Beta(1, 2), Pareto(1, 2)),
26562704
(Binomial(31, 0.7), Binomial(30, 0.3)),
2705+
(Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
2706+
Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8]))),
26572707
(Chi2(1), Beta(2, 3)),
26582708
(Chi2(1), Pareto(2, 3)),
26592709
(Chi2(1), Uniform(-2, 3)),

torch/distributions/binomial.py

Lines changed: 31 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,15 @@
11
from numbers import Number
22
import torch
3-
import math
43
from torch.distributions import constraints
54
from torch.distributions.distribution import Distribution
65
from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
7-
from torch.distributions.utils import clamp_probs
86

97

108
class Binomial(Distribution):
119
r"""
1210
Creates a Binomial distribution parameterized by `total_count` and
13-
either `probs` or `logits` (but not both).
14-
15-
- Requires a single shared `total_count` for all
16-
parameters and samples.
11+
either `probs` or `logits` (but not both). `total_count` must be
12+
broadcastable with `probs`/`logits`.
1713
1814
Example::
1915
@@ -25,26 +21,32 @@ class Binomial(Distribution):
2521
100
2622
[torch.FloatTensor of size 4]]
2723
24+
>>> m = Binomial(torch.Tensor([[5.], [10.]]), torch.Tensor([0.5, 0.8]))
25+
>>> x = m.sample()
26+
4 5
27+
7 6
28+
[torch.FloatTensor of size (2,2)]
29+
2830
Args:
29-
total_count (int): number of Bernoulli trials
31+
total_count (int or Tensor): number of Bernoulli trials
3032
probs (Tensor): Event probabilities
3133
logits (Tensor): Event log-odds
3234
"""
33-
arg_constraints = {'probs': constraints.unit_interval}
35+
arg_constraints = {'total_count': constraints.nonnegative_integer,
36+
'probs': constraints.unit_interval}
3437
has_enumerate_support = True
3538

3639
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
37-
if not isinstance(total_count, Number):
38-
raise NotImplementedError('inhomogeneous total_count is not supported')
39-
self.total_count = total_count
4040
if (probs is None) == (logits is None):
4141
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
4242
if probs is not None:
43-
is_scalar = isinstance(probs, Number)
44-
self.probs, = broadcast_all(probs)
43+
self.total_count, self.probs, = broadcast_all(total_count, probs)
44+
self.total_count = self.total_count.type_as(self.logits)
45+
is_scalar = isinstance(self.probs, Number)
4546
else:
46-
is_scalar = isinstance(logits, Number)
47-
self.logits, = broadcast_all(logits)
47+
self.total_count, self.logits, = broadcast_all(total_count, logits)
48+
self.total_count = self.total_count.type_as(self.logits)
49+
is_scalar = isinstance(self.logits, Number)
4850

4951
self._param = self.probs if probs is not None else self.logits
5052
if is_scalar:
@@ -81,14 +83,20 @@ def param_shape(self):
8183
return self._param.size()
8284

8385
def sample(self, sample_shape=torch.Size()):
84-
shape = self._extended_shape(sample_shape) + (self.total_count,)
8586
with torch.no_grad():
86-
return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1)
87+
max_count = max(int(self.total_count.max()), 1)
88+
shape = self._extended_shape(sample_shape) + (max_count,)
89+
bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
90+
if self.total_count.min() != max_count:
91+
arange = torch.arange(max_count, out=self.total_count.new_empty(max_count))
92+
mask = arange >= self.total_count.unsqueeze(-1)
93+
bernoullis.masked_fill_(mask, 0.)
94+
return bernoullis.sum(dim=-1)
8795

8896
def log_prob(self, value):
8997
if self._validate_args:
9098
self._validate_sample(value)
91-
log_factorial_n = math.lgamma(self.total_count + 1)
99+
log_factorial_n = torch.lgamma(self.total_count + 1)
92100
log_factorial_k = torch.lgamma(value + 1)
93101
log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
94102
max_val = (-self.logits).clamp(min=0.0)
@@ -98,8 +106,11 @@ def log_prob(self, value):
98106
self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))
99107

100108
def enumerate_support(self):
101-
values = self._new((self.total_count,))
102-
torch.arange(self.total_count, out=values.data)
109+
total_count = int(self.total_count.max())
110+
if not self.total_count.min() == total_count:
111+
raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.")
112+
values = self._new(1 + total_count,)
113+
torch.arange(1 + total_count, out=values)
103114
values = values.view((-1,) + (1,) * len(self._batch_shape))
104115
values = values.expand((-1,) + self._batch_shape)
105116
return values

torch/distributions/kl.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -198,12 +198,12 @@ def _kl_beta_beta(p, q):
198198
def _kl_binomial_binomial(p, q):
199199
# from https://math.stackexchange.com/questions/2214993/
200200
# kullback-leibler-divergence-for-binomial-distributions-p-and-q
201-
if p.total_count > q.total_count:
202-
return _infinite_like(p.probs)
203-
elif p.total_count == q.total_count:
204-
return p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
205-
else:
201+
if (p.total_count < q.total_count).any():
206202
raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
203+
kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
204+
inf_idxs = p.total_count > q.total_count
205+
kl[inf_idxs] = _infinite_like(kl[inf_idxs])
206+
return kl
207207

208208

209209
@register_kl(Categorical, Categorical)

0 commit comments

Comments
 (0)