Skip to content

Commit 98c24fa

Browse files
fritzoezyang
authored andcommitted
Fix broadcasting error in LogNormal and TransformedDistribution (#7269)
1 parent 8325206 commit 98c24fa

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

test/test_distributions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,11 @@ def test_lognormal(self):
11931193
self._gradcheck_log_prob(LogNormal, (mean, 1.0))
11941194
self._gradcheck_log_prob(LogNormal, (0.0, std))
11951195

1196+
# check .log_prob() can broadcast.
1197+
dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
1198+
log_prob = dist.log_prob(torch.ones(3, 1))
1199+
self.assertEqual(log_prob.shape, (2, 3, 4))
1200+
11961201
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
11971202
def test_lognormal_logprob(self):
11981203
mean = torch.randn(5, 1, requires_grad=True)

torch/distributions/transformed_distribution.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ def log_prob(self, value):
7979
y = value
8080
for transform in reversed(self.transforms):
8181
x = transform.inv(y)
82-
log_prob -= _sum_rightmost(transform.log_abs_det_jacobian(x, y),
83-
event_dim - transform.event_dim)
82+
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
83+
event_dim - transform.event_dim)
8484
y = x
8585

86-
log_prob += _sum_rightmost(self.base_dist.log_prob(y),
87-
event_dim - len(self.base_dist.event_shape))
86+
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
87+
event_dim - len(self.base_dist.event_shape))
8888
return log_prob
8989

9090
def _monotonize_cdf(self, value):

0 commit comments

Comments
 (0)