@@ -140,17 +140,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
         else:
             self._module_copies = [self.module]
 
-<<<<<<< 314d632e04ee8ce4bc35a9bfc181cf2def3d5fab
-=======
-        # TODO: different types need different buckets
-        t = None
-        for p in self.module.parameters():
-            tp = type(p.data)
-            if t is not None and t is not tp:
-                raise ValueError("DistributedDataParallel requires all parameters' data to be of the same type")
-            t = tp
-
->>>>>>> Added mixed precision support with nccl reduction bucketing
         # For NCCL backend, since every single NCCL call is asynchoronous, we
         # therefore directly enqueue all the NCCL reduction calls to the
         # default CUDA stream without spawning up other reduction threads.
@@ -159,7 +148,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
             self._register_nccl_grad_hook()
            return
 
-<<<<<<< 314d632e04ee8ce4bc35a9bfc181cf2def3d5fab
         bucket_bytes_cap = 1 * MB
 
         # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
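As an aside on why the 1 MB bucket cap exists (not part of this diff): reducing each gradient tensor with its own collective call is dominated by launch latency, so the gradients in a bucket are flattened into one contiguous buffer, reduced in a single call, and scattered back. A minimal sketch of that coalescing step, using the private torch._utils helpers (these are internal and may change between releases); the mul_ call stands in for the actual all-reduce:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Stand-ins for the per-parameter gradients collected into one bucket.
grads = [torch.randn(64, 64), torch.randn(64), torch.randn(32, 32)]

flat = _flatten_dense_tensors(grads)   # one contiguous 1-D buffer
flat.mul_(0.5)                         # placeholder for an all-reduce on the buffer

# Scatter the reduced values back into the original gradient tensors in place.
for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
    grad.copy_(synced)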
@@ -168,9 +156,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
         for dev_idx, module in enumerate(self._module_copies):
             param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap)))
 
-=======
-        # Split parameters into buckets that will coalesce reductions
->>>>>>> Added mixed precision support with nccl reduction bucketing
         self.bucket_sizes = []
         self.bucket_map = {}
 
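For reference, the bucketing call above can be exercised on its own. A small sketch, assuming a toy three-layer model (the helper is private to torch and yields groups of same-typed tensors whose combined size stays under the byte cap):

import torch.nn as nn
from torch._utils import _take_tensors

MB = 1024 * 1024
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 10))

# Each bucket holds parameters of a single type whose total size is at most 1 MB.
for i, bucket in enumerate(_take_tensors(model.parameters(), 1 * MB)):
    nbytes = sum(p.numel() * p.element_size() for p in bucket)
    print("bucket", i, "->", len(bucket), "tensors,", nbytes, "bytes")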
@@ -186,7 +171,6 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
             bucket_param_type = param_tuple[0].type()
             # Only gloo and nccl support half-precision
             if bucket_param_type == torch.cuda.HalfTensor and \
-                    dist._backend != dist.dist_backend.NCCL and \
                    dist._backend != dist.dist_backend.GLOO:
                 raise RuntimeError("DistributedDataParallel currently only "
                                    "supports half precision parameters "
@@ -356,6 +340,13 @@ def reduction_fn_nccl():
                 for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
                     grad.copy_(reduced)
 
+            # clear the gradients and save memory for replicas
+            for module in self._module_copies[1:]:
+                for param in module.parameters():
+                    if param.requires_grad:
+                        param.grad = None
+                        param.data.set_()
+
         # Now register the reduction hook on the parameters
         for p in self.module.parameters():
             if not p.requires_grad:
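Conceptually, the hook registration that follows hangs a callback off every parameter that requires gradients. The sketch below shows the same idea in its simplest per-parameter form, assuming an already initialized process group; the function name and structure are illustrative only, since the PR itself batches gradients into the buckets above and enqueues a single NCCL reduction rather than reducing each parameter separately:

import torch.distributed as dist

def register_simple_allreduce_hooks(module, world_size):
    # Average each parameter's gradient across processes as soon as it is produced.
    for p in module.parameters():
        if not p.requires_grad:
            continue

        def allreduce_hook(grad):
            dist.all_reduce(grad.data)   # sums the gradient across all ranks in place
            grad.data.div_(world_size)   # turn the sum into an average
            return grad

        p.register_hook(allreduce_hook)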