Commit 4b1df5c

Will Feng authored and facebook-github-bot committed
Use fn(param) instead of fn(param.data) in nn.Module._apply (#21865)
Summary:
When we pass `fn` to `nn.Module._apply()` and `fn` is an in-place operation, the correct behavior should also include bumping the parameters' and their gradients' version counters. This PR fixes the old incorrect behavior and makes sure the new behavior is right.

Note that this PR is BC-breaking in the following way: previously, passing an in-place operation to `nn.Module._apply()` did not bump the module's parameters' and their gradients' version counters. After this PR, the module's parameters' and their gradients' version counters are correctly bumped by the in-place operation, which invalidates them in any autograd graph they previously participated in.

Pull Request resolved: #21865

Differential Revision: D15881952

Pulled By: yf225

fbshipit-source-id: 62f9244a4283a110147e9f20145ff232a5579fbd
1 parent abd6cff commit 4b1df5c
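
For context, the invalidation described in the summary comes from autograd's version counters: every in-place modification of a tensor bumps its `_version`, and a later `backward()` through a graph that saved the tensor at an older version raises a RuntimeError. A minimal sketch of that mechanism (illustrative only, not part of this diff):

    import torch

    x = torch.ones(3, requires_grad=True)
    y = x.mul(x)                   # autograd saves x (at its current version) for backward
    with torch.no_grad():
        x.add_(1.0)                # the in-place op bumps x._version
    y.backward(torch.ones(3))      # RuntimeError: ... modified by an inplace operation

Before this PR, `_apply` ran the in-place `fn` on `param.data` rather than `param`, so (as the summary notes) the parameter's version counter was never bumped and graphs like `y` above were not invalidated; after this PR the error is raised.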

File tree

2 files changed: +31 -4 lines changed

  test/test_nn.py
  torch/nn/modules/module.py


test/test_nn.py

Lines changed: 25 additions & 0 deletions
@@ -1577,6 +1577,31 @@ def test_module_to_argparse(self):
         with self.assertRaises(TypeError):
             net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)
 
+    def test_module_apply_inplace_op(self):
+        def add_one_inplace(t):
+            return t.add_(1.0)
+
+        # Test that applying an in-place operation to a module would bump
+        # the module's parameters' version counter.
+        m = nn.Linear(20, 10)
+        pvm = m.weight.mul(m.weight)
+        m_weight_version_saved = m.weight._version
+        m = m._apply(add_one_inplace)
+        self.assertGreater(m.weight._version, m_weight_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pvm.backward(torch.randn(10, 20))
+
+        # Test that applying an in-place operation to a module would bump
+        # the module's parameters' gradients' version counter.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20).requires_grad_()
+        pgm = m.weight.grad.mul(m.weight.grad)
+        m_weight_grad_version_saved = m.weight.grad._version
+        m = m._apply(add_one_inplace)
+        self.assertGreater(m.weight.grad._version, m_weight_grad_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pgm.backward(torch.randn(10, 20))
+
     def test_type(self):
         l = nn.Linear(10, 20)
         net = nn.Module()

torch/nn/modules/module.py

Lines changed: 6 additions & 4 deletions
@@ -195,11 +195,13 @@ def _apply(self, fn):
 
         for param in self._parameters.values():
             if param is not None:
-                # Tensors stored in modules are graph leaves, and we don't
-                # want to create copy nodes, so we have to unpack the data.
-                param.data = fn(param.data)
+                with torch.no_grad():
+                    param_applied = fn(param)
+                param.data = param_applied
                 if param._grad is not None:
-                    param._grad.data = fn(param._grad.data)
+                    with torch.no_grad():
+                        grad_applied = fn(param._grad)
+                    param._grad.data = grad_applied
 
         for key, buf in self._buffers.items():
             if buf is not None:
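
A note on the torch.no_grad() blocks in the new code path: parameters are leaf tensors with requires_grad=True, and autograd rejects in-place operations on such leaves while grad mode is enabled, so `fn(param)` is called with gradient recording disabled; the in-place op still bumps the parameter's version counter, which is exactly what the new test checks. A rough sketch of that constraint (illustrative only, not from the diff):

    import torch
    from torch import nn

    p = nn.Parameter(torch.ones(3))
    # Under grad mode, p.add_(1.0) would fail with "a leaf Variable that
    # requires grad is being used in an in-place operation".
    with torch.no_grad():
        v = p._version
        p.add_(1.0)            # allowed under no_grad, not recorded in any graph
    assert p._version > v      # the version counter is still bumped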
