91 changes: 59 additions & 32 deletions torch/nn/parallel/distributed.py
@@ -89,14 +89,18 @@ class DistributedDataParallel(Module):
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
output_device: device location of output (default: device_ids[0])
broadcast_buffers: flag that enables syncing (broadcasting) buffers of
the module at the beginning of the forward function.
(default: True)

Example::

>>> torch.distributed.init_process_group(world_size=4, init_method='...')
>>> net = torch.nn.parallel.DistributedDataParallel(model)
"""

def __init__(self, module, device_ids=None, output_device=None, dim=0):
def __init__(self, module, device_ids=None, output_device=None, dim=0,
broadcast_buffers=True):
super(DistributedDataParallel, self).__init__()
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
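
For context, a minimal usage sketch of the new flag (not part of the patch; the backend, init_method URL, world_size/rank values and the toy model are placeholders a launcher would normally supply):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Placeholder process-group setup; a real launcher supplies the
# init_method URL, world_size and rank for every process.
dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:23456',
                        world_size=4, rank=0)

model = torch.nn.Linear(10, 10).cuda()

# Default behaviour: module buffers (e.g. BatchNorm running statistics)
# are re-broadcast from process 0 at the start of every forward pass.
net = DistributedDataParallel(model)

# New opt-out: each process keeps and updates its own buffer values.
net = DistributedDataParallel(model, broadcast_buffers=False)

Leaving broadcast_buffers at its default preserves the pre-patch behaviour; passing False is the new opt-out.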
@@ -106,6 +110,7 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
self.module = module
self.device_ids = device_ids
self.output_device = output_device
self.broadcast_buffers = broadcast_buffers

MB = 1024 * 1024
# used for intra-node param sync and inter-node sync as well
@@ -130,32 +135,52 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
else:
self._module_copies = [self.module]

# Split parameters into buckets that will coalesce reductions
# TODO: different types need different buckets
t = None
for p in self.module.parameters():
tp = type(p.data)
if t is not None and t is not tp:
raise ValueError("DistributedDataParallel requires all parameters' data to be of the same type")
t = tp

self.bucket_sizes = []
self.bucket_map = {}
# Currently NCCL backend only supports single reduction thread/bucket
if dist._backend == dist.dist_backend.NCCL:
bucket_bytes_cap = float('inf')
else:
bucket_bytes_cap = 1 * MB
bucket_bytes = bucket_bytes_cap # to init the first bucket immediately
for param_tuple in zip(*map(lambda m: m.parameters(), self._module_copies)):
if param_tuple[0].requires_grad:
if bucket_bytes >= bucket_bytes_cap:
self.bucket_sizes.append(0)
bucket_bytes = 0

# This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
param_buckets = []
# Split the parameters into buckets, grouping by type as well
for dev_idx, module in enumerate(self._module_copies):
param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap)))

self.bucket_sizes = []
self.bucket_map = {}
param_types = set()

# We transpose param_buckets, so the loop is over buckets.
# param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
self.bucket_sizes.append(0)
# Now, we transpose again, so we iterate over bucket_elems, but getting tuples
# of params from each device.
for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):

if idx == 0:
# Bucket parameter type tracking
bucket_param_type = param_tuple[0].type()
param_types.add(bucket_param_type)
# Only gloo and nccl support half-precision
if bucket_param_type == torch.cuda.HalfTensor and \
dist._backend != dist.dist_backend.NCCL and \
dist._backend != dist.dist_backend.GLOO:
raise RuntimeError("DistributedDataParallel currently only "
"supports half precision parameters "
"with Nccl and Gloo backend")
if not param_tuple[0].requires_grad:
continue
for p in param_tuple:
self.bucket_map[p] = len(self.bucket_sizes) - 1
bucket_bytes += p.numel() * p.element_size()
self.bucket_sizes[-1] += 1
self.bucket_map[p] = bucket_idx
self.bucket_sizes[bucket_idx] += 1

# TODO: add mixed precision support to the NCCL reduction code path.
# The check below is needed because the NCCL backend currently supports
# only a single reduction bucket.
if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL:
raise RuntimeError("DistributedDataParallel currently doesn't "
"support mixed precision type for NCCL backend")

self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
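
To make the bucketing above concrete, here is a small self-contained sketch of the same split-then-transpose pattern; take_tensors below is a simplified stand-in for the _take_tensors helper used in the patch (it ignores dtype grouping and the requires_grad filter), and the 1 MB cap and tensor sizes are made-up illustration values:

import torch

def take_tensors(tensors, size_limit):
    # Greedy split of a tensor list into chunks of at most size_limit bytes
    # (simplified stand-in for the internal _take_tensors helper).
    chunk, chunk_bytes = [], 0
    for t in tensors:
        nbytes = t.numel() * t.element_size()
        if chunk and chunk_bytes + nbytes > size_limit:
            yield chunk
            chunk, chunk_bytes = [], 0
        chunk.append(t)
        chunk_bytes += nbytes
    if chunk:
        yield chunk

# Two "device replicas" of the same three parameters (float32, 400 KB each).
replicas = [[torch.zeros(100000), torch.zeros(100000), torch.zeros(100000)]
            for _ in range(2)]

MB = 1024 * 1024
param_buckets = [list(take_tensors(params, 1 * MB)) for params in replicas]

bucket_map, bucket_sizes = {}, []
# zip(*param_buckets): iterate over buckets, taking one bucket per replica.
for bucket_idx, bucket_per_device in enumerate(zip(*param_buckets)):
    bucket_sizes.append(0)
    # zip(*bucket_per_device): iterate over parameters inside the bucket,
    # gathering the corresponding parameter from every replica.
    for param_tuple in zip(*bucket_per_device):
        for p in param_tuple:
            bucket_map[p] = bucket_idx
        bucket_sizes[bucket_idx] += 1

print(bucket_sizes)  # [2, 1]: the first two params fit under 1 MB, the third spills into a new bucket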
@@ -225,17 +250,19 @@ def _sync_params(self):
for tensor, param in zip(tensors, module.parameters()):
param.data.set_(tensor)

buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)

if len(self.device_ids) > 1:
# intra-node buffer sync
result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
for tensors, module in zip(result[1:], self._module_copies[1:]):
for tensor, buf in zip(tensors, module._all_buffers()):
buf.set_(tensor)
# module buffer sync
if self.broadcast_buffers:
buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)

if len(self.device_ids) > 1:
# intra-node buffer sync
result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
for tensors, module in zip(result[1:], self._module_copies[1:]):
for tensor, buf in zip(tensors, module._all_buffers()):
buf.set_(tensor)

def _register_grad_hooks(self):
self._grad_accs = [] # need to keep them in scope
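
For intuition on the coalesced buffer sync guarded by broadcast_buffers, here is a single-process sketch of the flatten, broadcast, unflatten pattern; the real code sends the flat tensor over the process group, while this sketch replaces that step with a plain copy so it runs standalone, and it assumes the _flatten_dense_tensors / _unflatten_dense_tensors helpers from torch._utils that the coalesced broadcasts build on:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Stand-ins for module._all_buffers() on two processes/devices.
buffers_rank0 = [torch.zeros(5), torch.arange(0, 3).float()]
buffers_rank1 = [torch.ones(5), torch.zeros(3)]

# Coalesce rank 0's buffers into one flat tensor so a single broadcast
# moves all of them at once (this is what "coalesced" refers to).
flat = _flatten_dense_tensors(buffers_rank0)

# In the real code path this flat tensor goes through a dist broadcast
# from process 0; here we simply hand it to the other replica directly.
received = flat.clone()

# Unflatten back into tensors shaped like the local buffers and copy in.
for buf, synced in zip(buffers_rank1,
                       _unflatten_dense_tensors(received, buffers_rank1)):
    buf.copy_(synced)

print(buffers_rank1[0])  # now equal to buffers_rank0[0]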