Skip to content

Commit f8cab38

Browse files
fehiepsifacebook-github-bot
authored andcommitted
Address precision matrix instability of MVN distribution (#21366)
Summary: Currently, when the input of MVN is precision matrix, we take inverse to convert the result to covariance matrix. This, however, will easily make the covariance matrix not positive definite, hence will trigger a cholesky error. For example, ``` import torch torch.manual_seed(0) x = torch.randn(10) P = torch.exp(-(x - x.unsqueeze(-1)) ** 2) torch.distributions.MultivariateNormal(loc=torch.ones(10), precision_matrix=P) ``` will trigger `RuntimeError: cholesky_cpu: U(8,8) is zero, singular U.` This PR uses some math tricks ([ref](https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril)) to only take inverse of a triangular matrix, hence increase the stability. cc fritzo, neerajprad , SsnL Pull Request resolved: #21366 Differential Revision: D15696972 Pulled By: ezyang fbshipit-source-id: cec13f7dfdbd06dee94b8bed8ff0b3e720c7a188
1 parent 8ece538 commit f8cab38

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

test/test_distributions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,6 +1781,11 @@ def gradcheck_func(samples, mu, sigma, prec, scale_tril):
17811781
multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
17821782
multivariate_normal_log_prob_gradcheck(mean_no_batch, None, None, scale_tril_batched)
17831783

1784+
def test_multivariate_normal_stable_with_precision_matrix(self):
1785+
x = torch.randn(10)
1786+
P = torch.exp(-(x - x.unsqueeze(-1)) ** 2) # RBF kernel
1787+
MultivariateNormal(x.new_zeros(10), precision_matrix=P)
1788+
17841789
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
17851790
def test_multivariate_normal_log_prob(self):
17861791
mean = torch.randn(3, requires_grad=True)

torch/distributions/multivariate_normal.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ def _batch_mahalanobis(bL, bx):
6666
return reshaped_M.reshape(bx_batch_shape)
6767

6868

69+
def _precision_to_scale_tril(P):
70+
# Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
71+
Lf = torch.cholesky(torch.flip(P, (-2, -1)))
72+
L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
73+
L = torch.triangular_solve(torch.eye(P.shape[-1], dtype=P.dtype, device=P.device),
74+
L_inv, upper=False)[0]
75+
return L
76+
77+
6978
class MultivariateNormal(Distribution):
7079
r"""
7180
Creates a multivariate normal (also called Gaussian) distribution
@@ -136,10 +145,10 @@ def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tri
136145

137146
if scale_tril is not None:
138147
self._unbroadcasted_scale_tril = scale_tril
139-
else:
140-
if precision_matrix is not None:
141-
self.covariance_matrix = torch.inverse(precision_matrix).expand_as(loc_)
142-
self._unbroadcasted_scale_tril = torch.cholesky(self.covariance_matrix)
148+
elif covariance_matrix is not None:
149+
self._unbroadcasted_scale_tril = torch.cholesky(covariance_matrix)
150+
else: # precision_matrix is not None
151+
self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
143152

144153
def expand(self, batch_shape, _instance=None):
145154
new = self._get_checked_instance(MultivariateNormal, _instance)

0 commit comments

Comments
 (0)