
Commit 54107ae

weiyangfb authored and facebook-github-bot committed
convert output_device at data_parallel from torch.device to index (#10189)
Summary: fixes #9984

Pull Request resolved: #10189
Differential Revision: D9545390
Pulled By: weiyangfb
fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
1 parent 045f862 commit 54107ae

File tree: 8 files changed, +74 −32 lines

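Before the per-file diffs, a usage sketch of the behavior this change enables (and which the new tests below exercise): device_ids and output_device can now be passed as torch.device objects as well as plain integer indices, and are normalized to indices internally. This is an illustrative sketch, assuming a machine with at least two CUDA GPUs; it is not part of the commit.

    import torch
    import torch.nn as nn

    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')

    # Module wrapper: devices given as torch.device objects rather than ints
    model = nn.DataParallel(nn.Linear(10, 5).to(cuda0),
                            device_ids=[cuda0, cuda1],
                            output_device=cuda0)
    out = model(torch.randn(20, 10, device=cuda0))

    # Functional form, mirroring the new test in test/test_nn.py
    out = nn.parallel.data_parallel(nn.Linear(10, 5).to(cuda0),
                                    torch.randn(20, 10, device=cuda0),
                                    device_ids=(cuda0, cuda1),
                                    output_device=cuda0)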

test/test_c10d.py

Lines changed: 7 additions & 4 deletions

@@ -567,8 +567,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def world_size(self):
         return 2

-    def _test_ddp_with_process_group(self, process_group):
-        gpus = gpus_for_rank(self.world_size)[self.rank]
+    def _test_ddp_with_process_group(self, process_group, gpus):
         model = Net()
         ddp_model = DistributedDataParallel(
             copy.deepcopy(model).cuda(gpus[0]),
@@ -620,14 +619,18 @@ def test_gloo_backend(self):
         options = c10d.ProcessGroupGloo.Options()
         options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))

     @skip_if_not_multigpu
     @skip_if_not_nccl
     def test_nccl_backend(self):
         store = c10d.TCPStore('localhost', self.port, self.is_master)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))

     @skip_if_not_multigpu
     def test_dist_broadcast_coalesced(self):

test/test_distributed.py

Lines changed: 17 additions & 10 deletions

@@ -1126,24 +1126,15 @@ def _test_DDP_2iter(
             # Shuffle the input so that DDP input is different
             input = input[torch.randperm(batch_size)]

-    @unittest.skipIf(
-        BACKEND != "nccl" and BACKEND != "gloo",
-        "Only Nccl & Gloo backend support DistributedDataParallel",
-    )
-    @skip_if_no_cuda_distributed
-    @skip_if_no_gpu
-    def test_DistributedDataParallel(self):
+    def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None):
         # Run a simple end to end DDP model, use result of single node model
         # as baseline
-        group, group_id, rank = self._init_global_test()
-        rank_to_GPU = self._init_multigpu_helper()

         # cpu training setup
         model = self._create_Net()

         # single gpu training setup
         model_gpu = copy.deepcopy(model)
-        gpu_subset = list(rank_to_GPU[rank])
         model_gpu.cuda(gpu_subset[0])

         # DDP training setup
@@ -1195,6 +1186,22 @@ def test_DistributedDataParallelCPU(self):
         )
         self._barrier()

+    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
+                     "Only Nccl & Gloo backend support DistributedDataParallel")
+    @skip_if_no_cuda_distributed
+    @skip_if_no_gpu
+    def test_DistributedDataParallel(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        gpus = list(rank_to_GPU[rank])
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank)
+
+        # test output_device
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
+        # test device_ids
+        gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))

 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]

test/test_nn.py

Lines changed: 18 additions & 0 deletions

@@ -3154,6 +3154,24 @@ def forward(self, input):
         self.assertEqual(out.get_device(), 0)
         self.assertEqual(out.data, expected_out)

+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_device_args(self):
+        cuda0 = torch.device('cuda:0')
+        cuda1 = torch.device('cuda:1')
+
+        # test output_device
+        l = nn.Linear(10, 5).to(cuda0, torch.float)
+        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
+        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
+        self.assertEqual(out, l(i))
+
+        # test device_ids
+        l = nn.Linear(10, 5).to(cuda0, torch.float)
+        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
+        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
+        self.assertEqual(out, l(i))
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()

torch/nn/parallel/_functions.py

Lines changed: 4 additions & 0 deletions

@@ -3,6 +3,7 @@
 import torch
 import torch.cuda.comm as comm
 from torch.autograd import Function
+from torch.cuda._utils import _get_device_index


 class Broadcast(Function):
@@ -11,6 +12,7 @@ class Broadcast(Function):
     def forward(ctx, target_gpus, *inputs):
         if not all(input.is_cuda for input in inputs):
             raise TypeError('Broadcast function not implemented for CPU tensors')
+        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
         ctx.target_gpus = target_gpus
         if len(inputs) == 0:
             return tuple()
@@ -50,6 +52,7 @@ class Gather(Function):
     @staticmethod
     def forward(ctx, target_device, dim, *inputs):
         assert all(map(lambda i: i.is_cuda, inputs))
+        target_device = _get_device_index(target_device, True)
         ctx.target_device = target_device
         ctx.dim = dim
         ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
@@ -76,6 +79,7 @@ class Scatter(Function):

     @staticmethod
     def forward(ctx, target_gpus, chunk_sizes, dim, input):
+        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
         ctx.dim = dim
         ctx.input_device = input.get_device() if input.is_cuda else -1
         streams = None
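All three autograd functions above now normalize their device arguments through torch.cuda._utils._get_device_index before storing them on the context. The helper itself is not shown in this diff; a minimal sketch of the behavior the call sites appear to rely on (an approximation for illustration, not the actual source of torch/cuda/_utils.py):

    import torch

    def _get_device_index(device, optional=False):
        # Approximation only: accept a torch.device, an int index, or None,
        # and return a plain int index usable by torch.cuda APIs.
        if isinstance(device, torch.device):
            if device.type != 'cuda':
                raise ValueError('Expected a cuda device, but got: {}'.format(device))
            index = device.index
        else:
            index = device
        if index is None:
            if optional:
                # e.g. torch.device('cuda') with no index: fall back to the current device
                return torch.cuda.current_device()
            raise ValueError('Expected a cuda device with a specified index '
                             'or an integer, but got: {}'.format(device))
        return index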

torch/nn/parallel/data_parallel.py

Lines changed: 12 additions & 10 deletions

@@ -5,6 +5,7 @@
 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index


 def _check_balance(device_ids):
@@ -13,7 +14,7 @@ def _check_balance(device_ids):
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
-
+    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
     dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

     def warn_imbalance(get_prop):
@@ -77,9 +78,9 @@ class DataParallel(Module):


     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])

     Attributes:
         module (Module): the module to be parallelized
@@ -104,10 +105,11 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
             device_ids = list(range(torch.cuda.device_count()))
         if output_device is None:
             output_device = device_ids[0]
+
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)

         _check_balance(self.device_ids)

@@ -143,10 +145,10 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
     This is the functional version of the DataParallel module.

     Args:
-        module: the module to evaluate in parallel
-        inputs: inputs to the module
-        device_ids: GPU ids on which to replicate module
-        output_device: GPU location of the output Use -1 to indicate the CPU.
+        module (Module): the module to evaluate in parallel
+        inputs (tensor): inputs to the module
+        device_ids (list of int or torch.device): GPU ids on which to replicate module
+        output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU.
             (default: device_ids[0])
     Returns:
         a Tensor containing the result of module(input) located on
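Given the constructor change above, one quick way to see the normalization in effect (a hypothetical check, assuming at least two visible GPUs; not part of the commit) is to inspect the stored attributes after wrapping a module:

    import torch
    import torch.nn as nn

    wrapper = nn.DataParallel(nn.Linear(4, 2).cuda(0),
                              device_ids=[torch.device('cuda:0'), torch.device('cuda:1')],
                              output_device=torch.device('cuda:0'))
    print(wrapper.device_ids)     # expected: [0, 1] -- stored as plain indices
    print(wrapper.output_device)  # expected: 0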

torch/nn/parallel/distributed.py

Lines changed: 7 additions & 6 deletions

@@ -12,6 +12,7 @@
 from .replicate import replicate
 from .scatter_gather import scatter_kwargs, gather
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index


 class DistributedDataParallel(Module):
@@ -90,10 +91,10 @@ class DistributedDataParallel(Module):
     :meth:`forward` method.

     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
-        broadcast_buffers: flag that enables syncing (broadcasting) buffers of
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
             the module at beginning of the forward function.
             (default: True)
         process_group: the c10d process group to be used for distributed data
@@ -133,8 +134,8 @@ def __init__(self, module, device_ids=None,

         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers

         self.allreduce_opts = dist.AllreduceOptions()

torch/nn/parallel/parallel_apply.py

Lines changed: 7 additions & 1 deletion

@@ -1,5 +1,6 @@
 import threading
 import torch
+from torch.cuda._utils import _get_device_index


 def get_a_var(obj):
@@ -22,6 +23,11 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
     contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
     on each of :attr:`devices`.

+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
     :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
     :attr:`devices` (if given) should all have same length. Moreover, each
     element of :attr:`inputs` can either be a single object as the only argument
@@ -36,7 +42,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
         assert len(modules) == len(devices)
     else:
         devices = [None] * len(modules)
-
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     lock = threading.Lock()
     results = {}
     grad_enabled = torch.is_grad_enabled()

torch/nn/parallel/replicate.py

Lines changed: 2 additions & 1 deletion

@@ -1,10 +1,11 @@
 import torch.cuda.comm as comm
+from torch.cuda._utils import _get_device_index


 def replicate(network, devices, detach=False):
     from ._functions import Broadcast

-    devices = tuple(devices)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     num_replicas = len(devices)

     params = list(network.parameters())
