Skip to content

Commit c946db1

Browse files
authored
[distributions] Always enable grad when calculating lazy_property (#7708)
* Always enable grad when calculating lazy_property * Add test with MultiVariableNormal
1 parent 4bf0202 commit c946db1

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

test/test_distributions.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
SoftmaxTransform,
5454
StickBreakingTransform,
5555
identity_transform)
56-
from torch.distributions.utils import _finfo, probs_to_logits, softmax
56+
from torch.distributions.utils import _finfo, probs_to_logits, softmax, lazy_property
5757

5858
TEST_NUMPY = True
5959
try:
@@ -690,6 +690,31 @@ def test_enumerate_support_type(self):
690690
except NotImplementedError:
691691
pass
692692

693+
def test_lazy_property_grad(self):
694+
x = torch.randn(1, requires_grad=True)
695+
696+
class Dummy(object):
697+
@lazy_property
698+
def y(self):
699+
return x + 1
700+
701+
def test():
702+
x.grad = None
703+
Dummy().y.backward()
704+
self.assertEqual(x.grad, torch.ones(1))
705+
706+
test()
707+
with torch.no_grad():
708+
test()
709+
710+
mean = torch.randn(2)
711+
cov = torch.eye(2, requires_grad=True)
712+
distn = MultivariateNormal(mean, cov)
713+
with torch.no_grad():
714+
distn.scale_tril
715+
distn.scale_tril.sum().backward()
716+
self.assertIsNotNone(cov.grad)
717+
693718
def test_has_examples(self):
694719
distributions_with_examples = set(e.Dist for e in EXAMPLES)
695720
for Dist in globals().values():

torch/distributions/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def __init__(self, wrapped):
182182
def __get__(self, instance, obj_type=None):
183183
if instance is None:
184184
return self
185-
value = self.wrapped(instance)
185+
with torch.enable_grad():
186+
value = self.wrapped(instance)
186187
setattr(instance, self.wrapped.__name__, value)
187188
return value

0 commit comments

Comments
 (0)