11 changes: 7 additions & 4 deletions test/test_c10d.py
@@ -567,8 +567,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
def world_size(self):
return 2

def _test_ddp_with_process_group(self, process_group):
gpus = gpus_for_rank(self.world_size)[self.rank]
def _test_ddp_with_process_group(self, process_group, gpus):
model = Net()
ddp_model = DistributedDataParallel(
copy.deepcopy(model).cuda(gpus[0]),
@@ -620,14 +619,18 @@ def test_gloo_backend(self):
options = c10d.ProcessGroupGloo.Options()
options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
self._test_ddp_with_process_group(process_group)
gpus = gpus_for_rank(self.world_size)[self.rank]
self._test_ddp_with_process_group(process_group, gpus)
self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))

@skip_if_not_multigpu
@skip_if_not_nccl
def test_nccl_backend(self):
store = c10d.TCPStore('localhost', self.port, self.is_master)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
self._test_ddp_with_process_group(process_group)
gpus = gpus_for_rank(self.world_size)[self.rank]
self._test_ddp_with_process_group(process_group, gpus)
self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))

@skip_if_not_multigpu
def test_dist_broadcast_coalesced(self):
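The test now exercises _test_ddp_with_process_group twice per backend: once with the plain integer GPU ids returned by gpus_for_rank, and once with the same ids wrapped in torch.device objects. A minimal sketch of the equivalence being exercised (the id list here is illustrative, not part of the patch):

import torch

gpus = [0, 1]  # e.g. gpus_for_rank(world_size)[rank]
devices = [torch.device('cuda:' + str(i)) for i in gpus]  # same GPUs, spelled as torch.device
assert [d.index for d in devices] == gpus  # both spellings name the same devices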
27 changes: 17 additions & 10 deletions test/test_distributed.py
@@ -1126,24 +1126,15 @@ def _test_DDP_2iter(
# Shuffle the input so that DDP input is different
input = input[torch.randperm(batch_size)]

@unittest.skipIf(
BACKEND != "nccl" and BACKEND != "gloo",
"Only Nccl & Gloo backend support DistributedDataParallel",
)
@skip_if_no_cuda_distributed
@skip_if_no_gpu
def test_DistributedDataParallel(self):
def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None):
# Run a simple end to end DDP model, use result of single node model
# as baseline
group, group_id, rank = self._init_global_test()
rank_to_GPU = self._init_multigpu_helper()

# cpu training setup
model = self._create_Net()

# single gpu training setup
model_gpu = copy.deepcopy(model)
gpu_subset = list(rank_to_GPU[rank])
model_gpu.cuda(gpu_subset[0])

# DDP training setup
@@ -1195,6 +1186,22 @@ def test_DistributedDataParallelCPU(self):
)
self._barrier()

@unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
"Only Nccl & Gloo backend support DistributedDataParallel")
@skip_if_no_cuda_distributed
@skip_if_no_gpu
def test_DistributedDataParallel(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = self._init_multigpu_helper()
gpus = list(rank_to_GPU[rank])
self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank)

# test output_device
self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))

# test device_ids
gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))

if BACKEND == "gloo" or BACKEND == "nccl":
WORLD_SIZE = os.environ["WORLD_SIZE"]
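Besides the integer and torch.device id lists, the new test_DistributedDataParallel also passes output_device=torch.device('cuda') with no explicit index. A small illustrative snippet of the two spellings involved (not part of the patch):

import torch

explicit = torch.device('cuda:0')  # device with an explicit index
implicit = torch.device('cuda')    # no index; expected to resolve to the current device
print(explicit.index, implicit.index)  # prints: 0 None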
18 changes: 18 additions & 0 deletions test/test_nn.py
@@ -3154,6 +3154,24 @@ def forward(self, input):
self.assertEqual(out.get_device(), 0)
self.assertEqual(out.data, expected_out)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_data_parallel_device_args(self):
cuda0 = torch.device('cuda:0')
cuda1 = torch.device('cuda:1')

# test output_device
l = nn.Linear(10, 5).to(cuda0, torch.float)
i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
self.assertEqual(out, l(i))

# test device_ids
l = nn.Linear(10, 5).to(cuda0, torch.float)
i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
self.assertEqual(out, l(i))

def test_state_dict(self):
l = nn.Linear(5, 5)
block = nn.Module()
4 changes: 4 additions & 0 deletions torch/nn/parallel/_functions.py
@@ -3,6 +3,7 @@
import torch
import torch.cuda.comm as comm
from torch.autograd import Function
from torch.cuda._utils import _get_device_index


class Broadcast(Function):
@@ -11,6 +12,7 @@ class Broadcast(Function):
def forward(ctx, target_gpus, *inputs):
if not all(input.is_cuda for input in inputs):
raise TypeError('Broadcast function not implemented for CPU tensors')
target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
ctx.target_gpus = target_gpus
if len(inputs) == 0:
return tuple()
@@ -50,6 +52,7 @@ class Gather(Function):
@staticmethod
def forward(ctx, target_device, dim, *inputs):
assert all(map(lambda i: i.is_cuda, inputs))
target_device = _get_device_index(target_device, True)
ctx.target_device = target_device
ctx.dim = dim
ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
@@ -76,6 +79,7 @@ class Scatter(Function):

@staticmethod
def forward(ctx, target_gpus, chunk_sizes, dim, input):
target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
ctx.dim = dim
ctx.input_device = input.get_device() if input.is_cuda else -1
streams = None
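All three autograd functions (Broadcast, Gather, Scatter) now funnel their device arguments through torch.cuda._utils._get_device_index before storing them, so the rest of the code keeps seeing plain integer indices. The helper itself is private; the sketch below is only an assumption about the behavior these call sites rely on, not the real implementation:

import torch

def device_to_index(device):
    # Illustrative stand-in for _get_device_index(x, True): accept ints or
    # torch.device objects and return a plain integer GPU index.
    if isinstance(device, torch.device):
        if device.type != 'cuda':
            raise ValueError('expected a cuda device, got {}'.format(device))
        if device.index is not None:
            return device.index
        return torch.cuda.current_device()  # index-less 'cuda' falls back to the current device
    return int(device)

print(device_to_index(1))                       # 1
print(device_to_index(torch.device('cuda:1')))  # 1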
22 changes: 12 additions & 10 deletions torch/nn/parallel/data_parallel.py
@@ -5,6 +5,7 @@
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch.cuda._utils import _get_device_index


def _check_balance(device_ids):
Expand All @@ -13,7 +14,7 @@ def _check_balance(device_ids):
has less than 75% of the memory or cores of GPU {}. You can do so by setting
the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
environment variable."""

device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

def warn_imbalance(get_prop):
@@ -77,9 +78,9 @@ class DataParallel(Module):


Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
output_device: device location of output (default: device_ids[0])
module (Module): module to be parallelized
device_ids (list of int or torch.device): CUDA devices (default: all devices)
output_device (int or torch.device): device location of output (default: device_ids[0])

Attributes:
module (Module): the module to be parallelized
@@ -104,10 +105,11 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]

self.dim = dim
self.module = module
self.device_ids = device_ids
self.output_device = output_device
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)

_check_balance(self.device_ids)

@@ -143,10 +145,10 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
This is the functional version of the DataParallel module.

Args:
module: the module to evaluate in parallel
inputs: inputs to the module
device_ids: GPU ids on which to replicate module
output_device: GPU location of the output Use -1 to indicate the CPU.
module (Module): the module to evaluate in parallel
inputs (tensor): inputs to the module
device_ids (list of int or torch.device): GPU ids on which to replicate module
output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU.
(default: device_ids[0])
Returns:
a Tensor containing the result of module(input) located on
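Since DataParallel.__init__ now normalizes device_ids and output_device, the module accepts either spelling. A minimal usage sketch, guarded because it assumes at least two visible GPUs:

import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    model = nn.Linear(10, 5).cuda(0)
    # Integer ids and torch.device objects are interchangeable after this change.
    dp_int = nn.DataParallel(model, device_ids=[0, 1], output_device=0)
    dp_dev = nn.DataParallel(model,
                             device_ids=[torch.device('cuda:0'), torch.device('cuda:1')],
                             output_device=torch.device('cuda:0'))
    x = torch.randn(20, 10, device='cuda:0')
    assert torch.allclose(dp_int(x), dp_dev(x))  # same replicas, same result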
13 changes: 7 additions & 6 deletions torch/nn/parallel/distributed.py
@@ -12,6 +12,7 @@
from .replicate import replicate
from .scatter_gather import scatter_kwargs, gather
from .parallel_apply import parallel_apply
from torch.cuda._utils import _get_device_index


class DistributedDataParallel(Module):
@@ -90,10 +91,10 @@ class DistributedDataParallel(Module):
:meth:`forward` method.

Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
output_device: device location of output (default: device_ids[0])
broadcast_buffers: flag that enables syncing (broadcasting) buffers of
module (Module): module to be parallelized
device_ids (list of int or torch.device): CUDA devices (default: all devices)
output_device (int or torch.device): device location of output (default: device_ids[0])
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
the module at beginning of the forward function.
(default: True)
process_group: the c10d process group to be used for distributed data
@@ -133,8 +134,8 @@ def __init__(self, module, device_ids=None,

self.dim = dim
self.module = module
self.device_ids = device_ids
self.output_device = output_device
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)
self.broadcast_buffers = broadcast_buffers

self.allreduce_opts = dist.AllreduceOptions()
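DistributedDataParallel applies the same normalization in its constructor, so device_ids and output_device can be given as torch.device objects there as well. A rough sketch (process-group setup is elided; assumes this rank owns GPUs 0 and 1, as in the test_c10d tests above):

import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_model(model, process_group):
    # process_group is assumed to have been created elsewhere,
    # e.g. via c10d as shown in test_c10d.py.
    devices = [torch.device('cuda:0'), torch.device('cuda:1')]
    return DistributedDataParallel(model.cuda(0),
                                   device_ids=devices,
                                   output_device=torch.device('cuda:0'),
                                   process_group=process_group)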
8 changes: 7 additions & 1 deletion torch/nn/parallel/parallel_apply.py
@@ -1,5 +1,6 @@
import threading
import torch
from torch.cuda._utils import _get_device_index


def get_a_var(obj):
@@ -22,6 +23,11 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.

Args:
modules (Module): modules to be parallelized
inputs (tensor): inputs to the modules
devices (list of int or torch.device): CUDA devices

:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
Expand All @@ -36,7 +42,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)

devices = list(map(lambda x: _get_device_index(x, True), devices))

lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
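parallel_apply receives its devices from replicate() and DataParallel.forward(), so it now normalizes them as well. A rough sketch of calling it directly with torch.device objects (guarded; assumes two GPUs, and the module and inputs here are illustrative):

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, parallel_apply

if torch.cuda.device_count() >= 2:
    net = nn.Linear(10, 5).cuda(0)
    devices = [torch.device('cuda:0'), torch.device('cuda:1')]
    replicas = replicate(net, devices)                         # one replica per device
    inputs = [torch.randn(4, 10, device=d) for d in devices]   # one input chunk per replica
    outputs = parallel_apply(replicas, inputs, devices=devices)
    print([out.device for out in outputs])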
3 changes: 2 additions & 1 deletion torch/nn/parallel/replicate.py
@@ -1,10 +1,11 @@
import torch.cuda.comm as comm
from torch.cuda._utils import _get_device_index


def replicate(network, devices, detach=False):
from ._functions import Broadcast

devices = tuple(devices)
devices = list(map(lambda x: _get_device_index(x, True), devices))
num_replicas = len(devices)

params = list(network.parameters())
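replicate() gets the same per-element normalization, so after this change the devices sequence may even mix integers with torch.device objects. A minimal hedged sketch:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

if torch.cuda.device_count() >= 2:
    net = nn.Linear(10, 5).cuda(0)
    replicas = replicate(net, [0, torch.device('cuda:1')])  # mixed int / torch.device ids
    print([next(r.parameters()).device for r in replicas])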