
Commit c6a923f

Support modules that output scalar in Gather (and data parallel) (#7973)
* Support modules that output scalar in Gather (and data parallel)
* Improve warning msg
1 parent 215abff commit c6a923f

File tree: 3 files changed (+43, -3 lines changed)

test/test_nn.py

Lines changed: 23 additions & 2 deletions
@@ -2287,8 +2287,8 @@ def test_scatter_gpu(self):

     def _test_gather(self, output_device):
         inputs = (
-            Variable(torch.randn(2, 4).cuda(0), requires_grad=True),
-            Variable(torch.randn(2, 4).cuda(1), requires_grad=True)
+            torch.randn(2, 4, device='cuda:0', requires_grad=True),
+            torch.randn(2, 4, device='cuda:1', requires_grad=True),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([4, 4]))
@@ -2306,6 +2306,27 @@ def _test_gather(self, output_device):
         self.assertEqual(inputs[1].grad.data, grad[2:])
         _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

+        # test scalar inputs, should stack into a vector in this case
+        inputs = (
+            torch.randn((), device='cuda:0', requires_grad=True),
+            torch.randn((), device='cuda:1', requires_grad=True),
+        )
+        result = dp.gather(inputs, output_device)
+        self.assertEqual(result.size(), torch.Size([2]))
+        self.assertEqual(result[0], inputs[0])
+        self.assertEqual(result[1], inputs[1])
+        if output_device != -1:
+            self.assertEqual(result.get_device(), output_device)
+        else:
+            self.assertFalse(result.is_cuda)
+        grad = torch.randn(2)
+        if output_device != -1:
+            grad = grad.cuda(output_device)
+        result.backward(grad)
+        self.assertEqual(inputs[0].grad, grad[0])
+        self.assertEqual(inputs[1].grad, grad[1])
+        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_gather_cpu(self):
         self._test_gather(-1)
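For readers who want to try the new behavior outside the test suite, below is a minimal standalone sketch of what the updated `_test_gather` exercises. It assumes a machine with at least two CUDA devices and imports `torch.nn.parallel` under the `dp` alias, as the test module appears to do; passing `-1` as the output device gathers onto the CPU.

import torch
import torch.nn.parallel as dp

# Two 0-dimensional tensors, one per GPU, mimicking per-replica scalar outputs.
a = torch.randn((), device='cuda:0', requires_grad=True)
b = torch.randn((), device='cuda:1', requires_grad=True)

# With this commit, gather stacks the scalars into a length-2 vector
# (and emits a warning) instead of failing to gather 0-dim tensors along dim 0.
out = dp.gather((a, b), -1)        # -1 gathers onto the CPU
print(out.size())                  # torch.Size([2])

out.backward(torch.ones(2))        # each input receives a 0-dim gradient back
print(a.grad.dim(), b.grad.dim())  # 0 0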

torch/nn/parallel/_functions.py

Lines changed: 14 additions & 1 deletion
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 import torch.cuda.comm as comm
 from torch.autograd import Function
@@ -51,12 +53,23 @@ def forward(ctx, target_device, dim, *inputs):
         ctx.target_device = target_device
         ctx.dim = dim
         ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
+        if all(t.dim() == 0 for t in inputs) and dim == 0:
+            inputs = tuple(t.view(1) for t in inputs)
+            warnings.warn('Was asked to gather along dimension 0, but all '
+                          'input tensors were scalars; will instead unsqueeze '
+                          'and return a vector.')
+            ctx.unsqueezed_scalar = True
+        else:
+            ctx.unsqueezed_scalar = False
         ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
         return comm.gather(inputs, ctx.dim, ctx.target_device)

     @staticmethod
     def backward(ctx, grad_output):
-        return (None, None) + Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
+        scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
+        if ctx.unsqueezed_scalar:
+            scattered_grads = tuple(g[0] for g in scattered_grads)
+        return (None, None) + scattered_grads


 class Scatter(Function):
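The core of the change is a view/re-index round trip: `forward` promotes each 0-dim input to shape `(1,)` so `comm.gather` can concatenate along dimension 0, and `backward` indexes each scattered gradient chunk with `[0]` so every scalar input gets a 0-dim gradient back. A rough CPU-only sketch of that round trip, using `torch.cat`/`split` purely as stand-ins for the CUDA `comm.gather`/`Scatter` calls:

import torch

# Pretend these are per-replica scalar outputs (0-dim tensors).
scalars = (torch.tensor(1.5), torch.tensor(-0.5))

# forward: view each 0-dim tensor as shape (1,) and concatenate along dim 0
gathered = torch.cat(tuple(t.view(1) for t in scalars), dim=0)  # shape (2,)

# backward: each scattered gradient chunk has shape (1,), so indexing with [0]
# recovers a 0-dim gradient matching the original scalar input
grad_chunks = gathered.split(1, dim=0)
scalar_grads = tuple(g[0] for g in grad_chunks)                 # 0-dim each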

torch/nn/parallel/data_parallel.py

Lines changed: 6 additions & 0 deletions
@@ -61,6 +61,12 @@ class DataParallel(Module):
         that each such hook be executed before the corresponding
         :meth:`~torch.nn.Module.forward` call of that device.

+    .. warning::
+        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
+        :func:`forward`, this wrapper will return a vector of length equal to
+        number of devices used in data parallelism, containing the result from
+        each device.
+
     .. note::
         There is a subtlety in using the
         ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
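In practice, the warning added above matters mostly for modules that compute their loss inside `forward`. A hedged sketch of that pattern (the module name, shapes, and device ids are made up for illustration; it assumes at least two visible GPUs): the wrapper gathers one scalar per replica into a vector, which the caller then reduces, for example with `mean()`, before calling `backward()`.

import torch
import torch.nn as nn

class ScalarLossModule(nn.Module):
    """Hypothetical module whose forward() returns a 0-dim loss directly."""

    def __init__(self):
        super(ScalarLossModule, self).__init__()
        self.weight = nn.Parameter(torch.ones(16))

    def forward(self, x):
        return (x * self.weight).pow(2).mean()  # 0-dim tensor per replica

model = nn.DataParallel(ScalarLossModule(), device_ids=[0, 1]).cuda()
x = torch.randn(8, 16, device='cuda:0')

losses = model(x)      # per the warning above: a vector of length len(device_ids)
loss = losses.mean()   # reduce to a single scalar before backpropagating
loss.backward()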
