
Commit 97e1489

Added logic to clear gradients on the replicas
1 parent 27e50f3 commit 97e1489


1 file changed: +7 -16 lines changed


torch/nn/parallel/distributed.py

Lines changed: 7 additions & 16 deletions
@@ -140,17 +140,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
         else:
             self._module_copies = [self.module]
 
-<<<<<<< 314d632e04ee8ce4bc35a9bfc181cf2def3d5fab
-=======
-        # TODO: different types need different buckets
-        t = None
-        for p in self.module.parameters():
-            tp = type(p.data)
-            if t is not None and t is not tp:
-                raise ValueError("DistributedDataParallel requires all parameters' data to be of the same type")
-            t = tp
-
->>>>>>> Added mixed precision support with nccl reduction bucketing
         # For NCCL backend, since every single NCCL call is asynchoronous, we
         # therefore directly enqueue all the NCCL reduction calls to the
         # default CUDA stream without spawning up other reduction threads.
@@ -159,7 +148,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
             self._register_nccl_grad_hook()
             return
 
-<<<<<<< 314d632e04ee8ce4bc35a9bfc181cf2def3d5fab
         bucket_bytes_cap = 1 * MB
 
         # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
@@ -168,9 +156,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
         for dev_idx, module in enumerate(self._module_copies):
             param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap)))
 
-=======
-        # Split parameters into buckets that will coalesce reductions
->>>>>>> Added mixed precision support with nccl reduction bucketing
         self.bucket_sizes = []
         self.bucket_map = {}
 
@@ -186,7 +171,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
             bucket_param_type = param_tuple[0].type()
             # Only gloo and nccl support half-precision
             if bucket_param_type == torch.cuda.HalfTensor and \
-                    dist._backend != dist.dist_backend.NCCL and \
                     dist._backend != dist.dist_backend.GLOO:
                 raise RuntimeError("DistributedDataParallel currently only "
                                    "supports half precision parameters "
@@ -356,6 +340,13 @@ def reduction_fn_nccl():
                 for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
                     grad.copy_(reduced)
 
+            # clear the gradients and save memory for replicas
+            for module in self._module_copies[1:]:
+                for param in module.parameters():
+                    if param.requires_grad:
+                        param.grad = None
+                        param.data.set_()
+
         # Now register the reduction hook on the parameters
         for p in self.module.parameters():
             if not p.requires_grad:
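The substantive change is the block added inside reduction_fn_nccl: once the reduced gradients have been copied back into the primary replica, the gradients and parameter storage of every other replica are released to save memory. The snippet below is a minimal, standalone sketch of that pattern, not the PyTorch source; the helper name clear_replica_grads and the module_copies argument are hypothetical stand-ins for the private self._module_copies machinery.

# Minimal sketch of the pattern added in this commit (not the PyTorch source).
# `clear_replica_grads` and `module_copies` are hypothetical stand-ins for the
# private machinery around `self._module_copies`.
import torch
import torch.nn as nn


def clear_replica_grads(module_copies):
    """Drop gradients (and, as in the commit, parameter storage) on every
    replica except the first, which already holds the reduced gradients."""
    for module in module_copies[1:]:
        for param in module.parameters():
            if param.requires_grad:
                # The replica's gradient is no longer needed after reduction.
                param.grad = None
                # The commit also empties the replica's parameter tensor; in the
                # 2018-era codebase `.data` aliased the parameter's storage, so
                # this released its memory (behaviour may differ in newer versions).
                param.data.set_()


if __name__ == "__main__":
    # Two toy "replicas" standing in for the per-device module copies.
    primary, replica = nn.Linear(4, 2), nn.Linear(4, 2)
    for m in (primary, replica):
        m(torch.randn(3, 4)).sum().backward()

    clear_replica_grads([primary, replica])
    print(primary.weight.grad is not None)  # True: the primary keeps its gradients
    print(replica.weight.grad is None)      # True: the replica's gradients were cleared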
