Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,31 @@ def test_module_to_argparse(self):
with self.assertRaises(TypeError):
net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)

def test_module_apply_inplace_op(self):
def add_one_inplace(t):
return t.add_(1.0)

# Test that applying an in-place operation to a module would bump
# the module's parameters' version counter.
m = nn.Linear(20, 10)
pvm = m.weight.mul(m.weight)
m_weight_version_saved = m.weight._version
m = m._apply(add_one_inplace)
self.assertGreater(m.weight._version, m_weight_version_saved)
with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
pvm.backward(torch.randn(10, 20))

# Test that applying an in-place operation to a module would bump
# the module's parameters' gradients' version counter.
m = nn.Linear(20, 10)
m.weight.grad = torch.randn(10, 20).requires_grad_()
pgm = m.weight.grad.mul(m.weight.grad)
m_weight_grad_version_saved = m.weight.grad._version
m = m._apply(add_one_inplace)
self.assertGreater(m.weight.grad._version, m_weight_grad_version_saved)
with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
pgm.backward(torch.randn(10, 20))

def test_type(self):
l = nn.Linear(10, 20)
net = nn.Module()
Expand Down
10 changes: 6 additions & 4 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,13 @@ def _apply(self, fn):

for param in self._parameters.values():
if param is not None:
# Tensors stored in modules are graph leaves, and we don't
# want to create copy nodes, so we have to unpack the data.
param.data = fn(param.data)
with torch.no_grad():
param_applied = fn(param)
param.data = param_applied
if param._grad is not None:
param._grad.data = fn(param._grad.data)
with torch.no_grad():
grad_applied = fn(param._grad)
param._grad.data = grad_applied

for key, buf in self._buffers.items():
if buf is not None:
Expand Down