Closed
Description
It seems that torch/nn/parallel/scatter_gather.py > Gather.apply(...) is broken when a module returns zero-dimensional (scalar) outputs and gathering happens along dim=0.
>>> import torch
>>> torch.__version__
'0.4.0'
>>> class Foo(torch.nn.Module):
...     def forward(self, x):
...         return x.mean()  # this gives a scalar output
...         # return x.mean().view(1)  # this is a quick fix
...
>>> foo = torch.nn.DataParallel(Foo(), [0, 1]).cuda()
>>> x = torch.zeros(2, 2)
>>> foo(x)
Traceback (most recent call last):
File "<input>", line 1, in <module>
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 115, in forward
return self.gather(outputs, self.output_device)
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 127, in gather
return gather(outputs, output_device, dim=self.dim)
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
return gather_map(outputs)
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
return Gather.apply(target_device, dim, *outputs)
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
File "/home/?????/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in <lambda>
ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
RuntimeError: dimension specified as 0 but tensor has no dimensions
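The root cause is that `x.mean()` returns a 0-dim tensor, which has no dimension 0 for `Gather` to concatenate along. A minimal CPU-only sketch (no DataParallel or GPUs needed) of why the workaround helps, simulating per-replica scalar outputs as an illustrative assumption:

```python
import torch

# Simulated per-replica outputs: 0-dim scalar tensors, as x.mean() returns.
# These cannot be concatenated along dim=0 -- they have no dimensions.
scalars = [torch.tensor(1.0), torch.tensor(2.0)]

# The quick fix from the snippet above: .view(1) gives each output an
# explicit dimension of size 1, so gathering along dim=0 is well defined.
vectors = [s.view(1) for s in scalars]
gathered = torch.cat(vectors, dim=0)
print(gathered.shape)  # one element per replica
```

Applying `.view(1)` inside `forward` before returning is the same trick, just done where DataParallel can see it.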
Related: #7568