
Commit d11b7fb

colesbury authored and ezyang committed
Don't modify requires_grad when running DataParallel in no_grad mode (#5880)
Previously, running DataParallel in no_grad mode changed the requires_grad property of the network's parameters to False. The issue is that Broadcast returns aliases of the inputs for the source device, and in no_grad mode it detached these inputs in place. Fixes #5851
1 parent 24fca0e commit d11b7fb
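
For context, a minimal reproduction sketch of the reported behavior (not part of the commit; it assumes a machine with at least two CUDA devices). Before this change, the final assertion would fail because the no_grad forward pass flipped the parameters' requires_grad to False:

import torch
import torch.nn as nn

# Wrap a small module in DataParallel and run a forward pass with grad disabled.
model = nn.DataParallel(nn.Linear(10, 5).cuda(), device_ids=[0, 1])
x = torch.randn(4, 10, device='cuda')

with torch.no_grad():
    model(x)

# Before this fix, Broadcast detached the source-device aliases in place,
# so the parameters themselves lost requires_grad; with the fix they keep it.
assert all(p.requires_grad for p in model.parameters())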

3 files changed: +36 −3 lines changed

test/test_autograd.py

Lines changed: 17 additions & 0 deletions
@@ -1527,6 +1527,23 @@ def test_no_grad_modifies_version(self):
         self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
                                lambda: z.backward())
 
+    def test_no_grad_input(self):
+        class MyFunction(Function):
+            @staticmethod
+            def forward(self, x):
+                return x
+
+            @staticmethod
+            def backward(self, grad_output):
+                return grad_output
+
+        x = torch.randn(5, requires_grad=True)
+        with torch.no_grad():
+            y = MyFunction.apply(x)
+
+        self.assertTrue(x.requires_grad)
+        self.assertIsNone(y.grad_fn)
+
     def test_backward_copy(self):
         # This tests checks backward engine for a very subtle bug that appreared
         # in one of the initial versions of autograd. Gradients tensors were
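
As a standalone sketch of what this test pins down (PassThrough is a hypothetical Function mirroring MyFunction above): the output of a pass-through Function applied under no_grad is an alias of the input's storage, so detaching it in place used to clobber the input's requires_grad.

import torch
from torch.autograd import Function

class PassThrough(Function):
    # Hypothetical Function whose output aliases its input -- the case this commit fixes.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.randn(5, requires_grad=True)
with torch.no_grad():
    y = PassThrough.apply(x)

print(y.data_ptr() == x.data_ptr())  # True: y shares x's storage
print(x.requires_grad)               # True: the input is left untouched
print(y.grad_fn)                     # None: no graph was recorded under no_grad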

test/test_nn.py

Lines changed: 9 additions & 0 deletions
@@ -1905,6 +1905,15 @@ def test_broadcast_not_requiring_grad(self):
             input_var = variables[output_idx % len(variables)]
             self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_broadcast_no_grad(self):
+        x = torch.randn(1, 2, dtype=torch.cuda.float32, requires_grad=True)
+        with torch.no_grad():
+            broadcasted = Broadcast.apply((0, 1), x)
+        self.assertTrue(x.requires_grad)
+        for output in broadcasted:
+            self.assertFalse(output.requires_grad)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate(self):
         module = nn.Linear(10, 5).float().cuda()

torch/csrc/autograd/python_function.cpp

Lines changed: 10 additions & 3 deletions
@@ -373,15 +373,22 @@ static void _wrap_outputs(THPFunction *self,
   auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
                          bool is_differentiable) {
     if (!is_differentiable) {
-      if (!var.requires_grad()) return;
+      if (!var.requires_grad()) {
+        return;
+      }
       // NB: we don't support returning non-differentiable views that could require grad
-      // (this could happen if someone were to return an input to the function).
       if (var.is_view()) {
         throw std::runtime_error("Returning Variables sharing storage with other Variables "
                                  "that require grad is not supported in Python functions. "
                                  "Please submit a feature request if you hit this error.");
       }
-      var.detach_();
+      // Return detached aliases of inputs, instead of changing their requires_grad
+      // property.
+      if (is_input) {
+        var = var.detach();
+      } else {
+        var.detach_();
+      }
     } else if (is_modified) {
       if (var.is_leaf() && var.requires_grad()) {
         throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
