25 changes: 23 additions & 2 deletions test/test_nn.py
@@ -2287,8 +2287,8 @@ def test_scatter_gpu(self):

def _test_gather(self, output_device):
inputs = (
Variable(torch.randn(2, 4).cuda(0), requires_grad=True),
Variable(torch.randn(2, 4).cuda(1), requires_grad=True)
torch.randn(2, 4, device='cuda:0', requires_grad=True),
torch.randn(2, 4, device='cuda:1', requires_grad=True),
)
result = dp.gather(inputs, output_device)
self.assertEqual(result.size(), torch.Size([4, 4]))
@@ -2306,6 +2306,27 @@ def _test_gather(self, output_device):
self.assertEqual(inputs[1].grad.data, grad[2:])
_assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

# test scalar inputs, should stack into a vector in this case
inputs = (
torch.randn((), device='cuda:0', requires_grad=True),
torch.randn((), device='cuda:1', requires_grad=True),
)
result = dp.gather(inputs, output_device)
self.assertEqual(result.size(), torch.Size([2]))
self.assertEqual(result[0], inputs[0])
self.assertEqual(result[1], inputs[1])
if output_device != -1:
self.assertEqual(result.get_device(), output_device)
else:
self.assertFalse(result.is_cuda)
grad = torch.randn(2)
if output_device != -1:
grad = grad.cuda(output_device)
result.backward(grad)
self.assertEqual(inputs[0].grad, grad[0])
self.assertEqual(inputs[1].grad, grad[1])
_assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_cpu(self):
self._test_gather(-1)
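A minimal usage sketch of the behavior the new test exercises, assuming a machine with at least two CUDA devices; the values below are illustrative only:

```python
import torch
from torch.nn.parallel import gather

# Two 0-dim (scalar) tensors living on different GPUs, as in _test_gather above.
a = torch.tensor(1.5, device='cuda:0', requires_grad=True)
b = torch.tensor(2.5, device='cuda:1', requires_grad=True)

# -1 gathers onto the CPU (the same convention the test uses); a device index
# would gather onto that GPU instead. The scalars are stacked into a length-2
# vector, and Gather.forward emits a warning about the implicit unsqueeze.
out = gather((a, b), -1)
print(out.size())            # torch.Size([2])

out.backward(torch.tensor([1.0, 1.0]))
print(a.grad, b.grad)        # each scalar input receives its matching grad element
```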
15 changes: 14 additions & 1 deletion torch/nn/parallel/_functions.py
@@ -1,3 +1,5 @@
import warnings

import torch
import torch.cuda.comm as comm
from torch.autograd import Function
@@ -51,12 +53,23 @@ def forward(ctx, target_device, dim, *inputs):
ctx.target_device = target_device
ctx.dim = dim
ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
if all(t.dim() == 0 for t in inputs) and dim == 0:

inputs = tuple(t.view(1) for t in inputs)
warnings.warn('Was asked to gather along dimension 0, but all '
'input tensors were scalars; will instead unsqueeze '
'and return a vector.')
ctx.unsqueezed_scalar = True
else:
ctx.unsqueezed_scalar = False
ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
return comm.gather(inputs, ctx.dim, ctx.target_device)

@staticmethod
def backward(ctx, grad_output):
return (None, None) + Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
if ctx.unsqueezed_scalar:
scattered_grads = tuple(g[0] for g in scattered_grads)
return (None, None) + scattered_grads


class Scatter(Function):
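For reference, a CPU-only sketch of the unsqueeze/re-squeeze round trip that `Gather.forward` and `Gather.backward` now perform for scalar inputs; `torch.cat` and `torch.split` stand in for `comm.gather` and `Scatter`, which this sketch does not call:

```python
import torch

# Stand-ins for 0-dim per-GPU outputs (kept on CPU for illustration).
scalars = (torch.tensor(1.0), torch.tensor(2.0))

# forward: view each scalar as a 1-element vector, then gather along dim 0.
unsqueezed = tuple(t.view(1) for t in scalars)
gathered = torch.cat(unsqueezed, dim=0)
assert gathered.shape == (2,)

# backward: the scattered gradients are 1-element vectors, so they are
# indexed with [0] to recover 0-dim grads matching the original inputs.
grad_output = torch.tensor([0.1, 0.2])
scattered = torch.split(grad_output, [1, 1], dim=0)
grads = tuple(g[0] for g in scattered)
assert all(g.dim() == 0 for g in grads)
```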
6 changes: 6 additions & 0 deletions torch/nn/parallel/data_parallel.py
@@ -61,6 +61,12 @@ class DataParallel(Module):
that each such hook be executed before the corresponding
:meth:`~torch.nn.Module.forward` call of that device.

.. warning::
When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
:func:`forward`, this wrapper will return a vector of length equal to the
number of devices used in data parallelism, containing the result from
each device.

.. note::
There is a subtlety in using the
``pack sequence -> recurrent network -> unpack sequence`` pattern in a
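A hedged illustration of the warning added above, assuming at least two visible GPUs; `SumLoss` is a hypothetical module that returns a scalar per replica:

```python
import torch
import torch.nn as nn

class SumLoss(nn.Module):
    def forward(self, x):
        return x.sum()                      # 0-dim (scalar) tensor per replica

model = nn.DataParallel(SumLoss(), device_ids=[0, 1])
out = model(torch.randn(8, 4, device='cuda:0'))
print(out.shape)                            # torch.Size([2]): one scalar per device
loss = out.mean()                           # typical reduction before calling backward()
```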