I have a MultivariateNormal distribution whose loc is the output of a neural net (given an input) and whose diagonal covariance matrix has trainable parameters (which do not depend on the input).
If I sample via distr.rsample().detach() and optimise the sum of log_probs, .backward() produces correct gradients w.r.t. both loc and the covariance matrix parameters. But if I sample via distr.sample(), .backward() leaves the covariance matrix parameters with None gradients.
Here is a minimal reproduction:
import torch
from torch.distributions import MultivariateNormal
torch.manual_seed(124)
inp = torch.randn(3)
W = torch.randn((2, 3), requires_grad=True)
loc = W @ inp
cov = torch.tensor([[1., 0.], [0., 2.]], requires_grad=True)
distr = MultivariateNormal(loc, cov)
y = distr.rsample().detach() # difference here
loss = distr.log_prob(y)
loss.backward()
W.grad, cov.grad
yields
(tensor([[-0.1991, -1.0776, -0.6339],
[ 0.3455, 1.8700, 1.1000]]),
tensor([[-0.2678, 0.0000],
[-0.8057, 0.4491]]))
while
import torch
from torch.distributions import MultivariateNormal
torch.manual_seed(124)
inp = torch.randn(3)
W = torch.randn((2, 3), requires_grad=True)
loc = W @ inp
cov = torch.tensor([[1., 0.], [0., 2.]], requires_grad=True)
distr = MultivariateNormal(loc, cov)
x = distr.sample() # difference here
loss = distr.log_prob(x)
loss.backward()
W.grad, cov.grad
yields
(tensor([[-0.1991, -1.0776, -0.6339],
[ 0.3455, 1.8700, 1.1000]]),
None)
I am observing this problem on both macOS and Ubuntu.
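In the meantime, replacing distr.sample() with distr.rsample().detach() works as a drop-in substitute when the draw is meant to be treated as a constant for log_prob: as shown above, the covariance parameters then receive gradients. Below is a minimal sketch of a single training step using this substitution; the optimiser, loss sign, and diagonal parametrisation are illustrative assumptions, not part of the original report.
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(124)
inp = torch.randn(3)
W = torch.randn((2, 3), requires_grad=True)
cov_params = torch.tensor([1., 2.], requires_grad=True)  # trainable diagonal entries
opt = torch.optim.SGD([W, cov_params], lr=1e-2)

def training_step():
    loc = W @ inp
    distr = MultivariateNormal(loc, torch.diag(cov_params))
    # Workaround: rsample().detach() instead of sample(); the draw is still a
    # constant w.r.t. the graph, but cov_params now gets a gradient from log_prob.
    y = distr.rsample().detach()
    loss = -distr.log_prob(y)  # illustrative: minimise negative log-likelihood of the draw
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

print(training_step())
print(W.grad, cov_params.grad)  # both populated, unlike with distr.sample()
With this substitution the step behaves the same as the sample() version, except that the covariance parameters actually receive gradient updates.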