Commit f8cab38
Address precision matrix instability of MVN distribution (#21366)
Summary:
Currently, when the input of MVN is precision matrix, we take inverse to convert the result to covariance matrix. This, however, will easily make the covariance matrix not positive definite, hence will trigger a cholesky error.
For example,
```
import torch
torch.manual_seed(0)
x = torch.randn(10)
P = torch.exp(-(x - x.unsqueeze(-1)) ** 2)
torch.distributions.MultivariateNormal(loc=torch.ones(10), precision_matrix=P)
```
will trigger `RuntimeError: cholesky_cpu: U(8,8) is zero, singular U.`
This PR uses some math tricks ([ref](https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril)) to only take inverse of a triangular matrix, hence increase the stability.
cc fritzo, neerajprad , SsnL
Pull Request resolved: #21366
Differential Revision: D15696972
Pulled By: ezyang
fbshipit-source-id: cec13f7dfdbd06dee94b8bed8ff0b3e720c7a1881 parent 8ece538 commit f8cab38
File tree
2 files changed
+18
-4
lines changed- test
- torch/distributions
2 files changed
+18
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1781 | 1781 | | |
1782 | 1782 | | |
1783 | 1783 | | |
| 1784 | + | |
| 1785 | + | |
| 1786 | + | |
| 1787 | + | |
| 1788 | + | |
1784 | 1789 | | |
1785 | 1790 | | |
1786 | 1791 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
66 | 66 | | |
67 | 67 | | |
68 | 68 | | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
69 | 78 | | |
70 | 79 | | |
71 | 80 | | |
| |||
136 | 145 | | |
137 | 146 | | |
138 | 147 | | |
139 | | - | |
140 | | - | |
141 | | - | |
142 | | - | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
143 | 152 | | |
144 | 153 | | |
145 | 154 | | |
| |||
0 commit comments