
torch.device and torch.nn.parallel.data_parallel compatibility #9984

@PetrochukM

Issue description

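torch.nn.parallel.data_parallel does not accept a torch.device for its output_device argument. The device is passed down unchanged to torch._C._gather, which only accepts an integer destination index (or None), so the call fails with a TypeError.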

Code example

import torch
net = torch.nn.LSTM(10, 10)
torch.nn.parallel.data_parallel(module=net.cuda(), inputs=torch.randn(1, 1, 10).cuda(), output_device=torch.device('cuda'))  # Raises TypeError (traceback below)
torch.nn.parallel.data_parallel(module=net.cuda(), inputs=torch.randn(1, 1, 10).cuda(), output_device=torch.device('cuda').index)  # Works: .index is None here, so data_parallel falls back to its default output device
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/michaelp/.local/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 169, in data_parallel
    return gather(outputs, output_device, dim)
  File "/home/michaelp/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 67, in gather
    return gather_map(outputs)
  File "/home/michaelp/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/michaelp/.local/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 54, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/michaelp/.local/lib/python3.5/site-packages/torch/nn/parallel/_functions.py", line 65, in forward
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/home/michaelp/.local/lib/python3.5/site-packages/torch/cuda/comm.py", line 160, in gather
    return torch._C._gather(tensors, dim, destination)
TypeError: _gather(): incompatible function arguments. The following argument types are supported:
    1. (tensors: List[at::Tensor], dim: int, destination_index: Optional[int]) -> at::Tensor

Invoked with: (tensor([[[-0.0603, -0.0166, -0.0857,  0.0781,  0.1473,  0.1543,  0.0239,
           0.0820,  0.0887, -0.0435]]],
       device='cuda:0', grad_fn=<CudnnRnnBackward>),), 0, device(type='cuda')
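Until data_parallel normalizes devices itself, one workaround is to convert the device to an integer index (or None) before the call, since that is what _gather accepts according to the error above. The sketch below assumes a hypothetical helper named to_device_index (it is not part of PyTorch). It duck-types on the .type and .index attributes that torch.device exposes, also accepts strings such as 'cuda:1', and only handles CUDA devices. An index-less 'cuda' maps to None, which mirrors the .index workaround in the code example.

```python
def to_device_index(device):
    """Hypothetical helper: normalize a device spec to an int index or None.

    Accepts None, an int, a string like 'cuda' / 'cuda:1', or any object with
    .type and .index attributes (torch.device has both). Non-CUDA devices are
    rejected because this sketch only targets GPU output devices.
    """
    if device is None or isinstance(device, int):
        return device
    if isinstance(device, str):
        # 'cuda:1' -> ('cuda', ':', '1'); 'cuda' -> ('cuda', '', '')
        dev_type, _, idx = device.partition(":")
        index = int(idx) if idx else None
    else:
        # Duck-typed path for torch.device-like objects
        dev_type, index = device.type, device.index
    if dev_type != "cuda":
        raise ValueError("only CUDA devices are supported, got %r" % (device,))
    return index
```

Usage, under the same assumption: torch.nn.parallel.data_parallel(module=net.cuda(), inputs=torch.randn(1, 1, 10).cuda(), output_device=to_device_index(torch.device('cuda:0'))).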

System Info

  • PyTorch or Caffe2: PyTorch
  • How you installed PyTorch (conda, pip, source): pip
  • PyTorch version: 0.4.1
  • Python version: 3.5.2
