
[distributions] rsample().detach() and sample() yields different gradients #7705

@evgenii-nikishin

Description


I have a MultivariateNormal distribution whose loc is the output of a neural net (given an input) and whose diagonal covariance matrix has trainable parameters (which do not depend on the input).
If I sample via distr.rsample().detach() and optimise the sum of log_probs, .backward() produces correct gradients w.r.t. both loc and the covariance matrix parameters. But if I sample via distr.sample(), .backward() leaves the gradients of the covariance matrix parameters as None.

Here is a minimal reproducing example:

import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(124)
inp = torch.randn(3)
W = torch.randn((2, 3), requires_grad=True)
loc = W@inp
cov = torch.tensor([[1., 0.], [0., 2.]], requires_grad=True)
distr = MultivariateNormal(loc, cov)

y = distr.rsample().detach()  # difference here
loss = distr.log_prob(y)
loss.backward()
W.grad, cov.grad

yields

(tensor([[-0.1991, -1.0776, -0.6339],
         [ 0.3455,  1.8700,  1.1000]]), 
 tensor([[-0.2678,  0.0000],
         [-0.8057,  0.4491]]))

while

import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(124)
inp = torch.randn(3)
W = torch.randn((2, 3), requires_grad=True)
loc = W@inp
cov = torch.tensor([[1., 0.], [0., 2.]], requires_grad=True)
distr = MultivariateNormal(loc, cov)

x = distr.sample()  # difference here
loss = distr.log_prob(x)
loss.backward()
W.grad, cov.grad

yields

(tensor([[-0.1991, -1.0776, -0.6339],
         [ 0.3455,  1.8700,  1.1000]]), 
 None)

I am observing this problem on both macOS and Ubuntu.
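
For context, Distribution.sample() is implemented as rsample() wrapped in a torch.no_grad() block, so the two snippets above should differ only in the autograd context that is active while sampling. The sketch below is a guess at a minimal diagnostic: assuming the no_grad context is the only difference, and that some lazily computed internals of the distribution (e.g. the Cholesky factor used by log_prob) get cached under no_grad and thereby lose their graph connection to cov, it should reproduce the same None gradient without calling sample() at all:

import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(124)
inp = torch.randn(3)
W = torch.randn((2, 3), requires_grad=True)
loc = W @ inp
cov = torch.tensor([[1., 0.], [0., 2.]], requires_grad=True)
distr = MultivariateNormal(loc, cov)

# sample() is just rsample() under torch.no_grad(); if that context is the
# only difference, this should behave like the distr.sample() snippet above
with torch.no_grad():
    x = distr.rsample()

loss = distr.log_prob(x)
loss.backward()
W.grad, cov.grad  # expected (under the assumption above): cov.grad is None again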
