[Bug Report] DataParallel can't handle scalar output (PyTorch 0.4.0) #7956

@Jiaming-Liu

It seems that gather(...) in torch/nn/parallel/scatter_gather.py, which calls Gather.apply(...), fails when the module's outputs are 0-dimensional (scalar) tensors.

>>> import torch
>>> torch.__version__
'0.4.0'
>>> class Foo(torch.nn.Module):
...     def forward(self, x):
...         return x.mean() # this gives a scalar output
...         # return x.mean().view(1) # this is a quick fix
...     
>>> foo = torch.nn.DataParallel(Foo(),[0,1]).cuda()
>>> x = torch.zeros(2,2)
>>> foo(x)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 115, in forward
    return self.gather(outputs, self.output_device)
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 127, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    return gather_map(outputs)
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
  File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in <lambda>
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
RuntimeError: dimension specified as 0 but tensor has no dimensions
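
The quick fix noted in the comment above (returning a 1-element tensor instead of a scalar) could also be applied from outside the model. The following is only a rough sketch: `ScalarSafe` is a hypothetical helper name, not part of PyTorch, and the gather step is imitated on the CPU with `torch.cat` rather than with real GPU replicas.

```python
import torch


class ScalarSafe(torch.nn.Module):
    """Hypothetical wrapper (not a PyTorch API): promotes 0-dim outputs to shape (1,)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        out = self.module(*args, **kwargs)
        # This sketch only handles a single tensor output, not tuples or dicts.
        if torch.is_tensor(out) and out.dim() == 0:
            out = out.unsqueeze(0)
        return out


class Foo(torch.nn.Module):
    def forward(self, x):
        return x.mean()  # 0-dim (scalar) output, as in the report


# CPU-only stand-in for scatter/gather: split the batch along dim 0,
# run the wrapped module on each chunk, then join the results along dim 0.
chunks = torch.zeros(2, 2).chunk(2, dim=0)
wrapped = ScalarSafe(Foo())
per_replica = [wrapped(c) for c in chunks]
gathered = torch.cat(per_replica, dim=0)  # one entry per chunk
```

Under the same assumption, torch.nn.DataParallel(ScalarSafe(Foo()), [0, 1]).cuda() should return one value per device instead of raising; calling .mean() on that result gives a single scalar, which matches the full-batch mean only when every device receives an equally sized chunk.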

Related: #7568
