
Commit 8056399

ssnl authored and soumith committed
Broadcast output requires_grad only if corresponding input requires_grad (#5061)
1 parent c9ee47b commit 8056399

File tree: 5 files changed (+32, -0 lines)


test/common.py

Lines changed: 2 additions & 0 deletions

@@ -239,6 +239,8 @@ def assertTensorsEqual(a, b):
             super(TestCase, self).assertEqual(len(x), len(y), message)
             for x_, y_ in zip(x, y):
                 self.assertEqual(x_, y_, prec, message)
+        elif isinstance(x, bool) and isinstance(y, bool):
+            super(TestCase, self).assertEqual(x, y, message)
         elif isinstance(x, Number) and isinstance(y, Number):
             if abs(x) == float('inf') or abs(y) == float('inf'):
                 if allow_inf:
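The new branch matters because Python bools are also instances of numbers.Number, so without it True/False would fall through to the numeric, tolerance-based comparison further down. A minimal standalone sketch (illustration only, not part of the commit) of the distinction:

from numbers import Number

# Illustration only (not part of the commit): bool is a Number subtype, so
# before this change True/False were routed through the numeric comparison
# against a precision threshold instead of an exact equality check.
print(isinstance(True, bool))    # True  -- matches the new branch first
print(isinstance(True, Number))  # True  -- would otherwise match the Number branch
print(abs(True - False))         # 1     -- compared with a tolerance, not exactly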

test/test_nn.py

Lines changed: 14 additions & 0 deletions

@@ -1615,6 +1615,20 @@ def test_broadcast_double_backwards_gpu(self):
                   torch.randn(4, 4).cuda(),
                   torch.randn(4, 4).cuda())
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_broadcast_not_requiring_grad(self):
+        variables = [
+            Variable(torch.randn(1, 2).cuda(), requires_grad=True),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=False),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=False),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=True),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=True),
+        ]
+        broadcasted_variables = Broadcast.apply((0, 1), *variables)
+        for output_idx, broadcasted_var in enumerate(broadcasted_variables):
+            input_var = variables[output_idx % len(variables)]
+            self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate(self):
         module = nn.Linear(10, 5).float().cuda()
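The `output_idx % len(variables)` indexing works because Broadcast.forward flattens the per-device results in input order (see the `return tuple([t for tensors in outputs for t in tensors])` line in the next file), so each output maps back to its source input by taking the index modulo the number of inputs. A small sketch of that layout (plain index arithmetic, no GPU required; the numbers mirror the test above):

# Sketch of the output ordering the test relies on: for devices (0, 1) and
# five inputs, Broadcast.apply returns
#   (in0@gpu0, ..., in4@gpu0, in0@gpu1, ..., in4@gpu1),
# so outputs[i] corresponds to inputs[i % num_inputs].
num_inputs, num_devices = 5, 2
for output_idx in range(num_inputs * num_devices):
    input_idx = output_idx % num_inputs
    device_idx = output_idx // num_inputs
    print("output", output_idx, "-> input", input_idx, "on device", device_idx)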

torch/nn/parallel/_functions.py

Lines changed: 6 additions & 0 deletions

@@ -15,6 +15,12 @@ def forward(ctx, target_gpus, *inputs):
         ctx.num_inputs = len(inputs)
         ctx.input_device = inputs[0].get_device()
         outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
+        non_differentiables = []
+        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
+            if not input_requires_grad:
+                for output in outputs:
+                    non_differentiables.append(output[idx])
+        ctx.mark_non_differentiable(*non_differentiables)
         return tuple([t for tensors in outputs for t in tensors])
 
     @staticmethod
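This is the core of the fix: `ctx.needs_input_grad[1:]` skips the non-Variable `target_gpus` argument, and every broadcast copy of an input that does not require grad is gathered and marked non-differentiable in a single call, so those outputs come back with `requires_grad=False`. A hedged sketch of the same pattern in a self-contained custom autograd Function (hypothetical `Duplicate` function, not from the commit; written against the current tensor-based Function API rather than the Variable-era one):

import torch

# Hypothetical Duplicate function (not part of the commit) illustrating the
# same pattern: outputs whose source input does not need a gradient are
# gathered and marked non-differentiable in one call, so their requires_grad
# stays False.
class Duplicate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        outputs = (a.clone(), b.clone())
        non_differentiables = [out for out, needs_grad
                               in zip(outputs, ctx.needs_input_grad)
                               if not needs_grad]
        ctx.mark_non_differentiable(*non_differentiables)
        return outputs

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # clone() passes gradients straight through
        return grad_a, grad_b

x = torch.randn(2, requires_grad=True)
y = torch.randn(2)                      # requires_grad defaults to False
out_x, out_y = Duplicate.apply(x, y)
print(out_x.requires_grad, out_y.requires_grad)  # expected: True False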

torch/nn/parallel/data_parallel.py

Lines changed: 5 additions & 0 deletions

@@ -26,6 +26,11 @@ class DataParallel(Module):
     other types will be a shallow copy and can be corrupted if written to in
     the model's forward pass.
 
+    .. warning::
+        Forward and backward hooks defined on :attr:`module` and its submodules
+        won't be invoked anymore, unless the hooks are initialized in the
+        :meth:`forward` method.
+
     Args:
         module: module to be parallelized
         device_ids: CUDA devices (default: all devices)
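The warning reflects that DataParallel replicates `module` onto each device on every forward pass, and hooks registered ahead of time on the original module are not invoked on those replicas; registering the hook from inside `forward` means whichever replica runs sets it up for itself. A hedged sketch of that workaround (hypothetical `MyModel`, not part of the commit):

import torch
import torch.nn as nn

# Hypothetical workaround matching the warning above (not part of the commit):
# the hook is registered from inside forward(), so it is created on whichever
# replica actually runs, then removed to avoid accumulating hooks.
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        handle = self.fc.register_forward_hook(
            lambda module, inputs, output: print("fc output size:", output.size()))
        out = self.fc(x)
        handle.remove()
        return out

model = MyModel()
if torch.cuda.is_available():
    model = nn.DataParallel(model.cuda())
    out = model(torch.randn(4, 10).cuda())
else:
    out = model(torch.randn(4, 10))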

torch/nn/parallel/distributed.py

Lines changed: 5 additions & 0 deletions

@@ -78,6 +78,11 @@ class DistributedDataParallel(Module):
     (e.g. BatchNorm stats) are broadcast from the module in process of rank
     0, to all other replicas in the system in every iteration.
 
+    .. warning::
+        Forward and backward hooks defined on :attr:`module` and its submodules
+        won't be invoked anymore, unless the hooks are initialized in the
+        :meth:`forward` method.
+
     Args:
         module: module to be parallelized
         device_ids: CUDA devices (default: all devices)
