
a strange torch.no_grad behaviour when used with lazy_property from distributions #10207

@serhii-havrylov

Description

Recently I encountered a weird bug that can be reproduced with the following snippet of code:

import torch
from torch.nn import functional as F
from torch.distributions.one_hot_categorical import OneHotCategorical


scores = torch.tensor([[1, 4, 2],
                       [8, 4, 4]], dtype=torch.float32, requires_grad=True)
probs = F.softmax(scores, dim=-1)
cat = OneHotCategorical(probs)

print(cat.logits.requires_grad)      # True: logits are computed and cached here
with torch.no_grad():
    print(cat.logits.requires_grad)  # still True: the cached tensor is reused

sample = cat.sample()
log_prob = cat.log_prob(sample)
loss = log_prob.sum()
loss.backward()                      # works only because logits were first accessed outside no_grad

If one removes the print statement right before the no_grad context manager, calling backward raises a runtime error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.
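For concreteness, this is the failing variant (a sketch assuming the same scores/probs setup as above):

cat = OneHotCategorical(probs)

with torch.no_grad():
    print(cat.logits.requires_grad)  # False: logits are created and cached under no_grad

sample = cat.sample()
log_prob = cat.log_prob(sample)      # uses the cached logits, which have no grad_fn
loss = log_prob.sum()
loss.backward()                      # RuntimeError: element 0 of tensors does not require grad ...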

It turns out that accessing the lazy property cat.logits for the first time inside the no_grad context manager creates an attribute that does not require grad, and when cat.logits is later used outside no_grad it still does not require grad, even though it should. You can see this strange interplay of lazy_property and no_grad in this example:

import torch
from torch.distributions.utils import lazy_property


class MyScores:
    def __init__(self, score):
        self.score = score

    @lazy_property
    def negative_score(self):
        return -self.score


scores = torch.tensor([[1, 4, 2],
                       [8, 4, 4]], dtype=torch.float32, requires_grad=True)
my_scores = MyScores(scores)

with torch.no_grad():
    print(my_scores.negative_score.requires_grad)  # False: computed and cached under no_grad

loss = my_scores.negative_score.sum()  # reuses the cached tensor, which has no grad_fn
loss.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

This is not necessarily a bug, because the expected behaviour of no_grad is that any computation created inside the context manager does not require grad. So calling my_scores.negative_score within no_grad creates the negative_score attribute, which does not require grad, as expected. The counterintuitive part is that my_scores.negative_score still does not require grad outside of the context manager. A user does not necessarily know that it is a lazy property and expects it to have the same requires_grad value as the probs variable. Thus, the usage of lazy_property within distributions can lead to somewhat obscure bugs.
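For context, the caching is what makes the grad mode stick. Roughly (a simplified sketch, not the exact torch.distributions.utils.lazy_property source), a lazy_property-style descriptor runs the wrapped function once under whatever grad mode is active and then stores the result on the instance, shadowing the descriptor on every later access:

import functools

# Simplified sketch of a lazy_property-style caching descriptor
# (illustrative only; not the exact torch.distributions.utils implementation).
class sketch_lazy_property:
    def __init__(self, wrapped):
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        # The wrapped function runs exactly once, under whatever grad mode
        # happens to be active at this moment ...
        value = self.wrapped(instance)
        # ... and the result is stored on the instance, shadowing this
        # descriptor, so later accesses never recompute it.
        setattr(instance, self.wrapped.__name__, value)
        return value

Because the cached tensor replaces the descriptor on the instance, the requires_grad state it was created with is kept for the lifetime of the object, regardless of which grad mode is active afterwards.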

In my case, I was sampling some noise and used logits to obtain the correct dtype and device:

with torch.no_grad():
    uniforms = torch.empty(sample_shape, dtype=cat.logits.dtype, device=cat.logits.device).uniform_()

It was really surprising to see RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn later when calling loss.backward().
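One way to sidestep the problem (just a sketch, assuming sample_shape plus the cat and probs from the snippets above) is to either materialize the lazy property outside no_grad first, or to read dtype/device from the tensor the distribution was built from:

# Option 1: touch the lazy property once outside no_grad so the cached
# logits keep their grad_fn.
_ = cat.logits
with torch.no_grad():
    uniforms = torch.empty(sample_shape, dtype=cat.logits.dtype, device=cat.logits.device).uniform_()

# Option 2: avoid the lazy property inside no_grad altogether and take
# dtype/device from the tensor the distribution was constructed with.
with torch.no_grad():
    uniforms = torch.empty(sample_shape, dtype=probs.dtype, device=probs.device).uniform_()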

PyTorch version is 0.4
