Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,11 @@ def test_lognormal(self):
self._gradcheck_log_prob(LogNormal, (mean, 1.0))
self._gradcheck_log_prob(LogNormal, (0.0, std))

# check .log_prob() can broadcast.
dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
log_prob = dist.log_prob(torch.ones(3, 1))
self.assertEqual(log_prob.shape, (2, 3, 4))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lognormal_logprob(self):
mean = torch.randn(5, 1, requires_grad=True)
Expand Down
8 changes: 4 additions & 4 deletions torch/distributions/transformed_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def log_prob(self, value):
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
log_prob -= _sum_rightmost(transform.log_abs_det_jacobian(x, y),
event_dim - transform.event_dim)
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
event_dim - transform.event_dim)
y = x

log_prob += _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
return log_prob

def _monotonize_cdf(self, value):
Expand Down