125 changes: 103 additions & 22 deletions torch/nn/parallel/distributed.py
@@ -112,9 +112,14 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
self.output_device = output_device
self.broadcast_buffers = broadcast_buffers

# Flag used by the NCCL backend to make sure we only reduce gradients
# once in the execution engine
self.need_reduction = False

MB = 1024 * 1024
# Used for both intra-node param sync and inter-node sync
self.broadcast_bucket_size = 10 * MB
self.nccl_reduce_bucket_size = 256 * MB
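
As a rough, hedged illustration of what these byte caps control (not part of the diff): _take_tensors, the helper used later in this file, groups a flat list of tensors into buckets whose combined size stays under a cap. The import path, tensor sizes, and the 10 MB cap below are assumptions for the example only.

import torch
from torch._utils import _take_tensors  # assumed location of the helper

MB = 1024 * 1024
# Twenty ~1 MB float32 tensors (256 * 1024 elements * 4 bytes each).
tensors = [torch.randn(256, 1024) for _ in range(20)]

# _take_tensors yields lists of same-typed tensors whose total size stays
# within the cap, which is how broadcast_bucket_size and
# nccl_reduce_bucket_size are consumed.
for bucket in _take_tensors(tensors, 10 * MB):
    total = sum(t.numel() * t.element_size() for t in bucket)
    print(len(bucket), 'tensors,', total / MB, 'MB')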

# Sync params and buffers
module_states = list(self.module.state_dict().values())
@@ -135,11 +140,15 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
else:
self._module_copies = [self.module]

# Currently NCCL backend only supports single reduction thread/bucket
# For the NCCL backend, since every NCCL call is asynchronous, we
# directly enqueue all the NCCL reduction calls onto the
# default CUDA stream without spawning other reduction threads.
# This achieves the best performance.
if dist._backend == dist.dist_backend.NCCL:
bucket_bytes_cap = float('inf')
else:
bucket_bytes_cap = 1 * MB
self._register_nccl_grad_hook()
return

bucket_bytes_cap = 1 * MB

# This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
param_buckets = []
@@ -149,7 +158,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,

self.bucket_sizes = []
self.bucket_map = {}
param_types = set()

# We transpose param_buckets, so the loop is over buckets.
# param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
@@ -161,10 +169,8 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
if idx == 0:
# Bucket parameter type tracking
bucket_param_type = param_tuple[0].type()
param_types.add(bucket_param_type)
# Only gloo and nccl support half-precision
if bucket_param_type == torch.cuda.HalfTensor and \
dist._backend != dist.dist_backend.NCCL and \
dist._backend != dist.dist_backend.GLOO:
raise RuntimeError("DistributedDataParallel currently only "
"supports half precision parameters "
@@ -175,13 +181,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
self.bucket_map[p] = bucket_idx
self.bucket_sizes[bucket_idx] += 1

# TODO, adding mixed precision support in NCCL reduction code path
# This is because NCCL backend doesn't support multiple reduction
# bucket.
if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL:
raise RuntimeError("DistributedDataParallel currently doesn't "
"support mixed precision type for NCCL backend")

self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
self.reduced = [False] * len(self.bucket_sizes)
@@ -193,16 +192,22 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,

def __getstate__(self):
attrs = copy.copy(self.__dict__)
del attrs['_grad_accs'], attrs['_reduction_queues'], attrs['_reduction_streams'], \
attrs['_reduction_threads'], attrs['_nccl_streams'], attrs['_default_streams']
if dist._backend != dist.dist_backend.NCCL:
del attrs['_grad_accs'], attrs['_reduction_queues'], \
attrs['_reduction_streams'], attrs['_reduction_threads'], \
attrs['_nccl_streams'], attrs['_default_streams']
return attrs

def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
self._register_grad_hooks()
self._start_reduction_threads()
if dist._backend == dist.dist_backend.NCCL:
self._register_nccl_grad_hook()
else:
self._register_grad_hooks()
self._start_reduction_threads()

def forward(self, *inputs, **kwargs):
self.need_reduction = True
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
self._sync_params()
if len(self.device_ids) == 1:
@@ -274,7 +279,86 @@ def _register_grad_hooks(self):
grad_acc.register_hook(self._make_param_hook(p, device_idx))
self._grad_accs.append(grad_acc)
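
A minimal sketch (not part of the diff) of the AccumulateGrad-hook pattern the surrounding _register_grad_hooks code relies on: expanding a leaf parameter gives access to its AccumulateGrad node, and a hook registered there fires once that parameter's gradient has been computed during backward. The parameter and hook names are illustrative.

import torch

p = torch.nn.Parameter(torch.randn(3))
# p.expand_as(p).grad_fn is an ExpandBackward node whose next function is
# the AccumulateGrad node for p.
grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]

def on_grad_ready(*unused):
    print("gradient for p is ready")

grad_acc.register_hook(on_grad_ready)
(p * 2).sum().backward()   # prints once p.grad has been produced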

def _register_nccl_grad_hook(self):
"""
This function registers the all-reduce callback for the NCCL backend.
All gradients are all-reduced in a single step. The NCCL reduction is
enqueued directly onto the default CUDA stream, so no extra
synchronization is needed.
"""
# Creating a new group
self.nccl_reduction_group_id = dist.new_group()

def reduction_fn_nccl():
# This function only needs to be called once
if not self.need_reduction:
return

self.need_reduction = False
all_grads = [[] for _ in range(len(self._module_copies))]
all_grads_buckets_iters = []

# Bucketing all the gradients
for dev_idx, module in enumerate(self._module_copies):
for param in module.parameters():
if not param.requires_grad or param.grad is None:
continue
if param.grad.requires_grad:
raise RuntimeError("DistributedDataParallel only works "
"with gradients that don't require "
"grad")
# Adding the gradients for reduction
all_grads[dev_idx].append(param.grad.data)

# Now bucket the gradients for this device
dev_grads_buckets = _take_tensors(all_grads[dev_idx],
self.nccl_reduce_bucket_size)

all_grads_buckets_iters.append(dev_grads_buckets)

# Now reduce each bucket one after another
for grads_batch in zip(*all_grads_buckets_iters):
grads_batch_coalesced = []
# Coalesce each bucket
for dev_idx, dev_grads_batch in enumerate(grads_batch):
dev_id = self.device_ids[dev_idx]
with torch.cuda.device(dev_id):
dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
grads_batch_coalesced.append(dev_grads_batch_coalesced)

# We will only use device 0's results, but this single op should be
# faster than doing the following two operations sequentially:
# (1) intra-node reduce to the lead GPU, followed by
# (2) inter-node allreduce across the lead GPUs of all nodes
dist.all_reduce_multigpu(grads_batch_coalesced,
group=self.nccl_reduction_group_id)


# Now work only on the first device of self.device_ids and uncoalesce
# the gradients for each bucket
grads_batch_coalesced[0] /= dist.get_world_size()
grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
grad.copy_(reduced)

# Clear the gradients on the replicas to save memory
for module in self._module_copies[1:]:
for param in module.parameters():
if param.requires_grad:
param.grad = None
param.data.set_()

# Now register the reduction hook on the parameters
for p in self.module.parameters():
if not p.requires_grad:
continue

def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(reduction_fn_nccl)

p.register_hook(allreduce_hook)
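
The flow above can be exercised as a minimal, single-process sketch (not part of this diff, with the collective call stubbed out since it needs an initialized NCCL process group): a parameter hook queues a callback that the autograd engine runs once backward() has finished; the callback then flattens the gradients into one buffer, averages it in place of the all-reduce, and unflattens and copies the result back. The world_size value, the need_reduction stand-in, and the import path are assumptions for the example.

import torch
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors  # assumed location

model = torch.nn.Linear(4, 2)
world_size = 2                      # placeholder; real code uses dist.get_world_size()
state = {'need_reduction': False}   # stand-in for self.need_reduction

def sketch_reduction_fn():
    # Only the first queued callback per backward pass does any work.
    if not state['need_reduction']:
        return
    state['need_reduction'] = False
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]
    coalesced = _flatten_dense_tensors(grads)
    # The real code calls dist.all_reduce_multigpu([coalesced], group=...) here.
    coalesced /= world_size
    for grad, reduced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        grad.copy_(reduced)

def sketch_hook(*unused):
    # Runs while backward() executes; the queued callback runs after it ends.
    Variable._execution_engine.queue_callback(sketch_reduction_fn)

for p in model.parameters():
    p.register_hook(sketch_hook)

state['need_reduction'] = True      # the real code sets this flag in forward()
model(torch.randn(8, 4)).sum().backward()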

def _make_param_hook(self, param, device_idx):

bucket_idx = self.bucket_map[param]

def distributed_data_parallel_hook(*unused):
@@ -349,10 +433,7 @@ def _start_reduction_threads(self):
# We only use the first device for distributed reductions
dist._register_stream(reduction_streams[0])

if dist._backend == dist.dist_backend.NCCL:
group_id = dist.group.WORLD
else:
group_id = dist.new_group()
group_id = dist.new_group()

self._reduction_threads.append(threading.Thread(
target=self._reduction_thread_fn,
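
For context, a hedged usage sketch of how this code path gets exercised (not part of the diff; the init_method, world_size, rank, and model are placeholders, and the script is meant to be launched once per GPU/process): initialize the process group with the NCCL backend, wrap the module, and each backward() then triggers the hooks registered above.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Placeholders: in practice each process gets its own rank and a shared
# rendezvous address.
dist.init_process_group(backend='nccl',
                        init_method='tcp://127.0.0.1:23456',
                        world_size=2, rank=0)

model = torch.nn.Linear(10, 10).cuda()
model = DistributedDataParallel(model, device_ids=[0], output_device=0)

out = model(torch.randn(8, 10).cuda())
out.sum().backward()   # gradients are averaged across processes by the NCCL hook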