2 changes: 2 additions & 0 deletions test/common.py
@@ -237,6 +237,8 @@ def assertTensorsEqual(a, b):
super(TestCase, self).assertEqual(len(x), len(y), message)
for x_, y_ in zip(x, y):
self.assertEqual(x_, y_, prec, message)
elif isinstance(x, bool) and isinstance(y, bool):
super(TestCase, self).assertEqual(x, y, message)
elif isinstance(x, Number) and isinstance(y, Number):
if abs(x) == float('inf') or abs(y) == float('inf'):
if allow_inf:
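The new branch is needed because Python's bool is a subclass of int, so a True/False pair would otherwise satisfy the isinstance(x, Number) check below it and fall into the tolerance-based numeric comparison. A minimal sketch of that distinction (illustrative, not part of the patch):

    from numbers import Number

    # bool is an int subclass, so it also counts as a Number
    print(isinstance(True, int))     # True
    print(isinstance(True, Number))  # True
    # The dedicated bool branch therefore compares True/False exactly via the
    # plain unittest assertEqual, before the generic numeric path with
    # precision/inf handling is reached.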
14 changes: 14 additions & 0 deletions test/test_nn.py
@@ -1596,6 +1596,20 @@ def test_broadcast_double_backwards_gpu(self):
torch.randn(4, 4).cuda(),
torch.randn(4, 4).cuda())

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_broadcast_not_requiring_grad(self):
variables = [
Variable(torch.randn(1, 2).cuda(), requires_grad=True),
Variable(torch.randn(1, 2).cuda(), requires_grad=False),
Variable(torch.randn(1, 2).cuda(), requires_grad=False),
Variable(torch.randn(1, 2).cuda(), requires_grad=True),
Variable(torch.randn(1, 2).cuda(), requires_grad=True),
]
broadcasted_variables = Broadcast.apply((0, 1), *variables)
for output_idx, broadcasted_var in enumerate(broadcasted_variables):
input_var = variables[output_idx % len(variables)]
self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate(self):
module = nn.Linear(10, 5).float().cuda()
6 changes: 6 additions & 0 deletions torch/nn/parallel/_functions.py
@@ -15,6 +15,12 @@ def forward(ctx, target_gpus, *inputs):
ctx.num_inputs = len(inputs)
ctx.input_device = inputs[0].get_device()
outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
non_differentiables = []
for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
if not input_requires_grad:
for output in outputs:
non_differentiables.append(output[idx])
ctx.mark_non_differentiable(*non_differentiables)

return tuple([t for tensors in outputs for t in tensors])

@staticmethod
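The lines added above mark every broadcast copy of an input that does not require grad as non-differentiable, so those copies come back with requires_grad=False and backward skips them. A minimal sketch of how ctx.mark_non_differentiable behaves in a custom Function (illustrative names, assuming the standard torch.autograd.Function API; not code from this patch):

    import torch
    from torch.autograd import Function

    class MarkSecondOutput(Function):
        @staticmethod
        def forward(ctx, x):
            out_grad = x * 2     # stays differentiable
            out_plain = x + 1    # marked below, so it will not require grad
            ctx.mark_non_differentiable(out_plain)
            return out_grad, out_plain

        @staticmethod
        def backward(ctx, grad_out, grad_plain):
            # only the first output participates in autograd
            return grad_out * 2

    x = torch.randn(3, requires_grad=True)
    a, b = MarkSecondOutput.apply(x)
    print(a.requires_grad, b.requires_grad)  # expected: True False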
5 changes: 5 additions & 0 deletions torch/nn/parallel/data_parallel.py
@@ -26,6 +26,11 @@ class DataParallel(Module):
other types will be a shallow copy and can be corrupted if written to in
the model's forward pass.

.. warning::
Forward and backward hooks defined on :attr:`module` and its submodules
won't be invoked anymore, unless the hooks are initialized in the
:meth:`forward` method.

Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
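The new warning refers to hooks attached before wrapping: DataParallel builds fresh per-device replicas inside forward, and those replicas do not carry the original module's hooks. A hedged sketch of the situation the warning describes (illustrative, assumes two visible GPUs; behaviour as stated by the warning, not verified here):

    import torch
    import torch.nn as nn

    def log_output_shape(module, inputs, output):
        print('hook fired:', type(module).__name__, tuple(output.shape))

    model = nn.Linear(10, 5).cuda()
    model.register_forward_hook(log_output_shape)  # fires for direct calls to model(...)

    parallel = nn.DataParallel(model, device_ids=[0, 1])
    # Per the warning above, the per-device replicas created inside parallel(...)
    # do not carry this hook, so log_output_shape is not expected to run here.
    out = parallel(torch.randn(8, 10).cuda())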
5 changes: 5 additions & 0 deletions torch/nn/parallel/distributed.py
@@ -78,6 +78,11 @@ class DistributedDataParallel(Module):
(e.g. BatchNorm stats) are broadcast from the module in the process of rank
0 to all other replicas in the system in every iteration.

.. warning::
Forward and backward hooks defined on :attr:`module` and its submodules
won't be invoked anymore, unless the hooks are initialized in the
:meth:`forward` method.

Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)