Description
Recently I encountered a weird bug that can be reproduced with the following snippet:
```python
import torch
from torch.nn import functional as F
from torch.distributions.one_hot_categorical import OneHotCategorical

scores = torch.tensor([[1, 4, 2],
                       [8, 4, 4]], dtype=torch.float32, requires_grad=True)
probs = F.softmax(scores, dim=-1)
cat = OneHotCategorical(probs)

# Accessing cat.logits here (grad enabled) caches a tensor that requires grad.
# Remove this line and the backward() call below raises a RuntimeError.
print(cat.logits.requires_grad)

with torch.no_grad():
    print(cat.logits.requires_grad)

sample = cat.sample()
log_prob = cat.log_prob(sample)
loss = log_prob.sum()
loss.backward()
```

If one removes the `print` call right before the `no_grad` context manager, `backward()` fails with a runtime error: `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`
It turns out that accessing the lazy property `cat.logits` inside the `no_grad` context manager creates a cached attribute that doesn't require grad, and when `cat.logits` is later used outside `no_grad` it still doesn't require grad, even though it should. This strange interplay between `lazy_property` and `no_grad` can be seen in the following example:
```python
import torch
from torch.distributions.utils import lazy_property

class MyScores:
    def __init__(self, score):
        self.score = score

    @lazy_property
    def negative_score(self):
        return -self.score

scores = torch.tensor([[1, 4, 2],
                       [8, 4, 4]], dtype=torch.float32, requires_grad=True)
my_scores = MyScores(scores)

with torch.no_grad():
    # First access happens under no_grad, so the cached value has no grad_fn.
    print(my_scores.negative_score.requires_grad)

loss = my_scores.negative_score.sum()
loss.backward()  # raises the same RuntimeError
```

This is not necessarily a bug, because the expected behaviour of `no_grad` is that whatever computations happen inside this context manager do not require grad. So calling `my_scores.negative_score` within `no_grad` creates a `negative_score` attribute that does not require grad, as expected. The problem is that it is counterintuitive that `my_scores.negative_score` still doesn't require grad outside the context manager. A user doesn't necessarily know that it is a lazy property and expects it to have the same `requires_grad` value as the `probs` variable. Thus, the use of `lazy_property` within distributions can lead to obscure bugs.
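To see why the cached value outlives the `no_grad` block, here is a minimal pure-Python sketch of a caching descriptor in the style of `lazy_property`. It assumes that `lazy_property` computes the value on first access and then stores it on the instance with `setattr`. The global `GRAD_ENABLED` flag and the `no_grad_sim` context manager are hypothetical stand-ins for torch's grad mode, so the example runs without torch.

```python
import functools

GRAD_ENABLED = True  # hypothetical stand-in for torch's global grad mode


class no_grad_sim:
    """Hypothetical stand-in for torch.no_grad: switches the flag off."""

    def __enter__(self):
        global GRAD_ENABLED
        self.prev = GRAD_ENABLED
        GRAD_ENABLED = False
        return self

    def __exit__(self, *exc_info):
        global GRAD_ENABLED
        GRAD_ENABLED = self.prev
        return False


class lazy_property:
    """Caching descriptor sketch, assumed to behave like lazy_property."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        # Cache on the instance. This class defines no __set__, so later
        # lookups find the instance attribute and never call __get__ again.
        setattr(instance, self.wrapped.__name__, value)
        return value


class MyScores:
    def __init__(self, score):
        self.score = score
        self.calls = 0

    @lazy_property
    def negative_score(self):
        self.calls += 1
        # Record the grad mode that was active when the value was computed.
        return (-self.score, GRAD_ENABLED)


s = MyScores(3)
with no_grad_sim():
    print(s.negative_score)  # (-3, False)
print(s.negative_score)      # (-3, False): the cached value is reused
print(s.calls)               # 1
```

Because the descriptor defines only `__get__`, the instance attribute written by `setattr` wins on every later lookup, so whatever grad mode was active at the first access is baked into the cached value.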
In my case I was sampling some noise and used `cat.logits` to obtain the correct dtype and device:

```python
with torch.no_grad():
    uniforms = torch.empty(sample_shape, dtype=cat.logits.dtype, device=cat.logits.device).uniform_()
```

It was really surprising to see `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn` later when calling `loss.backward()`.
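A possible user-side workaround for the noise-sampling code, sketched under the assumption that the noise only needs a matching dtype and device: read them from an ordinary tensor such as `probs` instead of the lazy `cat.logits`, so no lazy attribute is first materialised inside `no_grad`. The value of `sample_shape` is hypothetical.

```python
import torch
from torch.nn import functional as F
from torch.distributions.one_hot_categorical import OneHotCategorical

scores = torch.tensor([[1, 4, 2],
                       [8, 4, 4]], dtype=torch.float32, requires_grad=True)
probs = F.softmax(scores, dim=-1)
cat = OneHotCategorical(probs)

sample_shape = (5,) + tuple(probs.shape)  # hypothetical noise shape

with torch.no_grad():
    # dtype/device come from the ordinary tensor `probs`, so the lazy
    # cat.logits is not evaluated (and cached) while grad is disabled.
    uniforms = torch.empty(sample_shape, dtype=probs.dtype,
                           device=probs.device).uniform_()

sample = cat.sample()
# log_prob reads the lazy logits outside no_grad, so they keep their graph.
loss = cat.log_prob(sample).sum()
loss.backward()
print(scores.grad is not None)
```

Alternatively, touching `cat.logits` once outside `no_grad` before entering the block caches a version that requires grad, which is why the first snippet only fails once the initial `print` is removed.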
PyTorch version is 0.4
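On the library side, one possible fix is to build the lazy value with autograd explicitly re-enabled, so the first access records the graph regardless of the caller's grad mode. The sketch below relies on `torch.enable_grad()`, which re-enables gradient computation inside an enclosing `torch.no_grad()` block; `grad_safe_lazy_property` is a hypothetical name, not an existing PyTorch API.

```python
import functools

import torch


class grad_safe_lazy_property:
    """Hypothetical lazy property that always computes with grad enabled."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        functools.update_wrapper(self, wrapped)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        # enable_grad overrides an enclosing no_grad for this computation only.
        with torch.enable_grad():
            value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value


class MyScores:
    def __init__(self, score):
        self.score = score

    @grad_safe_lazy_property
    def negative_score(self):
        return -self.score


scores = torch.tensor([[1, 4, 2],
                       [8, 4, 4]], dtype=torch.float32, requires_grad=True)
my_scores = MyScores(scores)

with torch.no_grad():
    print(my_scores.negative_score.requires_grad)  # True

loss = my_scores.negative_score.sum()
loss.backward()  # succeeds: the cached value kept its grad_fn
print(scores.grad)
```

The trade-off is that a lazy value first requested under `no_grad` still builds and keeps a small graph, but the cached attribute then behaves the same no matter where it was first touched.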