2 | 2 | import math |
3 | 3 | import threading |
4 | 4 | import copy |
| 5 | +from collections import defaultdict |
5 | 6 |
6 | 7 | import torch |
7 | 8 | from torch.autograd import Variable |
@@ -137,49 +138,58 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0, |
137 | 138 |
138 | 139 | # Split parameters into type buckets so that parameter sync (broadcast) |
139 | 140 |         # can operate on mixed parameter types (e.g. mixed half and float)
140 | | - self.param_type_buckets = {} |
| 141 | + self.param_type_buckets = \ |
| 142 | + defaultdict(lambda: [[] for _ in range(len(self.device_ids))]) |
| 143 | + |
141 | 144 | for device_idx, module in enumerate(self._module_copies): |
142 | 145 | for p in module.parameters(): |
143 | | - tp = type(p.data) |
| 146 | + tp = p.type() |
144 | 147 | if tp == torch.cuda.HalfTensor and \ |
145 | | - dist._backend != dist.dist_backend.NCCL: |
| 148 | + dist._backend != dist.dist_backend.NCCL and \ |
| 149 | + dist._backend != dist.dist_backend.GLOO: |
146 | 150 | raise RuntimeError("DistributedDataParallel currently only " |
147 | 151 | "supports half precision parameters " |
148 | 152 | "with NCCL backend") |
149 | | - if tp not in self.param_type_buckets: |
150 | | - self.param_type_buckets[tp] = \ |
151 | | - [[] for _ in range(len(self.device_ids))] |
152 | 153 | # Add the parameter into the type bucket |
153 | 154 | self.param_type_buckets[tp][device_idx].append(p) |
154 | 155 |
| 156 | +        # TODO: add mixed precision support to the NCCL reduction code path. |
| 157 | +        # This is needed because the NCCL backend doesn't support multiple |
| 158 | +        # reduction buckets. |
| 159 | + if len(self.param_type_buckets) > 1 and \ |
| 160 | + dist._backend == dist.dist_backend.NCCL: |
| 161 | + raise RuntimeError("DistributedDataParallel currently doesn't " |
| 162 | + "support mixed precision type for NCCL backend") |
| 163 | + |
155 | 164 | # Split parameters into buckets that will coalesce reductions |
156 | 165 | # |
157 | | - # Note that the NCCL backend currently only supports a single reduction |
158 | | - # bucket, so instead of splitting different Tensor types (half, float, |
159 | | - # double, etc) into separate buckets, which will form multipel buckets, |
160 | | - # we will split the parameters into reduction buckets regardless of |
161 | | - # the data types here. |
162 | | - # |
163 | | - # For each reductions bucket, at the reduction time, we will further |
164 | | - # split the gradients of different types into the each individual type |
165 | | - # bucket so that different types of gradients can be reduced. |
166 | | - |
| 166 | +        # Note that the parameters have already been split by type above. |
| 167 | +        # Here, for each type, we further split that type's parameters into |
| 168 | +        # reduction buckets, so that every bucket holds parameters of a |
| 169 | +        # single type. This is required because each all-reduce must |
| 170 | +        # operate on tensors of a single data type.
167 | 171 | self.bucket_sizes = [] |
168 | 172 | self.bucket_map = {} |
| 173 | + |
169 | 174 |         # Currently the NCCL backend only supports a single reduction thread/bucket
170 | 175 | if dist._backend == dist.dist_backend.NCCL: |
171 | 176 | bucket_bytes_cap = float('inf') |
172 | 177 | else: |
173 | 178 | bucket_bytes_cap = 1 * MB |
174 | | - bucket_bytes = bucket_bytes_cap # to init the first bucket immediately |
175 | | - for param_tuple in zip(*map(lambda m: m.parameters(), self._module_copies)): |
176 | | - if param_tuple[0].requires_grad: |
| 179 | + |
| 180 | + for tp in self.param_type_buckets: |
| 181 | + # to init the first bucket immediately for each type |
| 182 | + bucket_bytes = bucket_bytes_cap |
| 183 | + for param_idx, param in enumerate(self.param_type_buckets[tp][0]): |
| 184 | + if not param.requires_grad: |
| 185 | + continue |
177 | 186 | if bucket_bytes >= bucket_bytes_cap: |
178 | 187 | self.bucket_sizes.append(0) |
179 | 188 | bucket_bytes = 0 |
180 | | - for p in param_tuple: |
181 | | - self.bucket_map[p] = len(self.bucket_sizes) - 1 |
182 | | - bucket_bytes += p.numel() * p.element_size() |
| 189 | + for dev_idx in range(len(self.device_ids)): |
| 190 | + dev_param = self.param_type_buckets[tp][dev_idx][param_idx] |
| 191 | + self.bucket_map[dev_param] = len(self.bucket_sizes) - 1 |
| 192 | + bucket_bytes += param.numel() * param.element_size() |
183 | 193 | self.bucket_sizes[-1] += 1 |
184 | 194 |
185 | 195 | self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))] |
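
To make the hunk above easier to follow: parameters are first keyed by tensor type (so half and float replicas never share a bucket), and each type is then chunked into reduction buckets of roughly bucket_bytes_cap bytes. Below is a minimal, single-device sketch of that bucketing scheme; the helper name build_reduction_buckets and its standalone form are illustrative assumptions, not part of the patch.

```python
from collections import defaultdict

import torch

MB = 1024 * 1024


def build_reduction_buckets(params, bucket_bytes_cap=1 * MB):
    # Key parameters by tensor type (e.g. 'torch.cuda.FloatTensor' vs.
    # 'torch.cuda.HalfTensor'), mirroring param_type_buckets above.
    type_buckets = defaultdict(list)
    for p in params:
        type_buckets[p.type()].append(p)

    bucket_sizes = []  # number of gradients expected in each reduction bucket
    bucket_map = {}    # parameter -> index of the reduction bucket it joins
    for tp_params in type_buckets.values():
        # Start a fresh bucket for every type so a single bucket never mixes
        # data types: each coalesced all-reduce buffer must have one dtype.
        bucket_bytes = bucket_bytes_cap
        for p in tp_params:
            if not p.requires_grad:
                continue
            if bucket_bytes >= bucket_bytes_cap:
                bucket_sizes.append(0)
                bucket_bytes = 0
            bucket_map[p] = len(bucket_sizes) - 1
            bucket_bytes += p.numel() * p.element_size()
            bucket_sizes[-1] += 1
    return bucket_sizes, bucket_map


# Example usage with a small float-only model (all parameters share one type).
sizes, mapping = build_reduction_buckets(torch.nn.Linear(1024, 1024).parameters())
```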
@@ -249,10 +259,12 @@ def _sync_params(self): |
249 | 259 | result = broadcast_coalesced(params, |
250 | 260 | self.device_ids, |
251 | 261 | self.broadcast_bucket_size) |
252 | | - for tensors, device_id in zip(result[1:], self.device_ids[1:]): |
| 262 | + for idx, tensors in enumerate(result[1:]): |
| 263 | +                    # result[1:] lines up with self.device_ids[1:], so offset by one |
| 264 | + dev_idx = idx + 1 |
253 | 265 | for tensor, param in \ |
254 | 266 | zip(tensors, |
255 | | - self.param_type_buckets[tp][device_id]): |
| 267 | + self.param_type_buckets[tp][dev_idx]): |
256 | 268 | param.data.set_(tensor) |
257 | 269 |
258 | 270 | # module buffer sync |
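
The _sync_params hunk above switches from zipping over self.device_ids[1:] to enumerating result[1:], since param_type_buckets is indexed by device position rather than by device id. The following is a rough standalone sketch of that per-type parameter broadcast, assuming CUDA and at least two devices; the function name sync_replica_params is a hypothetical stand-in, not part of the patch.

```python
import torch
from torch.cuda.comm import broadcast_coalesced


def sync_replica_params(param_type_buckets, device_ids, bucket_size=10 * 1024 * 1024):
    # For each tensor type, broadcast the master copy on device_ids[0]
    # to every other replica in coalesced chunks of at most bucket_size bytes.
    for tp in param_type_buckets:
        master = [p.data for p in param_type_buckets[tp][0]]
        result = broadcast_coalesced(master, device_ids, bucket_size)
        # result[0] is the source copy; result[1:] lines up with device_ids[1:].
        for idx, tensors in enumerate(result[1:]):
            dev_idx = idx + 1
            for tensor, param in zip(tensors, param_type_buckets[tp][dev_idx]):
                # Point the replica parameter at the freshly broadcast storage.
                param.data.set_(tensor)
```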
@@ -373,53 +385,27 @@ def _reduction_thread_fn(queue, group_id, device_ids, reduction_streams, nccl_st |
373 | 385 | def _process_batch(): |
374 | 386 | dev_grad_batch, dev_events, job_event = queue.get() |
375 | 387 | dev_coalesced = [] |
376 | | - # For bucketing gradients with different data types |
377 | | - type_buckets = {} |
378 | | - |
379 | | - # Bucket the grad batch based on the data types: float, half etc |
380 | | - for dev_idx, grad_batch in enumerate(dev_grad_batch): |
381 | | - for grad in grad_batch: |
382 | | - tp = type(grad) |
383 | | - if tp not in type_buckets: |
384 | | - type_buckets[tp] = [[] for _ in range(len(device_ids))] |
385 | | - type_buckets[tp][dev_idx].append(grad) |
386 | | - |
387 | | - # Reducing for each data type if we have mixed-precision gradients |
388 | | - for tp in type_buckets: |
389 | | - tp_dev_grad_batch = type_buckets[tp] |
390 | | - # Coalesce the tensors on all devices and start a local |
391 | | - # reduction |
392 | | - for dev_id, tp_grad_batch, event, stream in zip( |
393 | | - device_ids, |
394 | | - tp_dev_grad_batch, |
395 | | - dev_events, |
396 | | - reduction_streams): |
397 | | - |
398 | | - with torch.cuda.device(dev_id), torch.cuda.stream(stream): |
399 | | - stream.wait_event(event) |
400 | | - coalesced = _flatten_dense_tensors(tp_grad_batch) |
401 | | - dev_coalesced.append(coalesced) |
402 | | - |
403 | | - # Wait for all copies to complete before starting the |
404 | | - # NCCL kernel |
405 | | - for stream in reduction_streams: |
406 | | - stream.synchronize() |
407 | | - nccl.reduce(dev_coalesced, root=0, streams=nccl_streams) |
408 | | - |
409 | | - # From now on we're only going to work on the |
410 | | - # first device (from device_ids) |
411 | | - tp_grad_batch = tp_dev_grad_batch[0] |
412 | | - coalesced = dev_coalesced[0] |
413 | | - reduce_stream = reduction_streams[0] |
414 | | - with torch.cuda.stream(reduce_stream): |
415 | | - reduce_stream.wait_stream(nccl_streams[0]) |
416 | | - coalesced /= dist.get_world_size() |
417 | | - dist.all_reduce(coalesced, group=group_id) |
418 | | - for tp_grad, reduced in zip( |
419 | | - tp_grad_batch, |
420 | | - _unflatten_dense_tensors(coalesced, tp_grad_batch)): |
421 | | - tp_grad.copy_(reduced) |
422 | | - |
| 388 | + # Coalesce the tensors on all devices and start a local reduction |
| 389 | + for dev_id, grad_batch, event, stream in zip(device_ids, dev_grad_batch, dev_events, reduction_streams): |
| 390 | + with torch.cuda.device(dev_id), torch.cuda.stream(stream): |
| 391 | + stream.wait_event(event) |
| 392 | + coalesced = _flatten_dense_tensors(grad_batch) |
| 393 | + dev_coalesced.append(coalesced) |
| 394 | + # Wait for all copies to complete before starting the NCCL kernel |
| 395 | + for stream in reduction_streams: |
| 396 | + stream.synchronize() |
| 397 | + nccl.reduce(dev_coalesced, root=0, streams=nccl_streams) |
| 398 | + |
| 399 | + # From now on we're only going to work on the first device (from device_ids) |
| 400 | + grad_batch = dev_grad_batch[0] |
| 401 | + coalesced = dev_coalesced[0] |
| 402 | + reduce_stream = reduction_streams[0] |
| 403 | + with torch.cuda.stream(reduce_stream): |
| 404 | + reduce_stream.wait_stream(nccl_streams[0]) |
| 405 | + coalesced /= dist.get_world_size() |
| 406 | + dist.all_reduce(coalesced, group=group_id) |
| 407 | + for grad, reduced in zip(grad_batch, _unflatten_dense_tensors(coalesced, grad_batch)): |
| 408 | + grad.copy_(reduced) |
423 | 409 | job_event.set() |
424 | 410 |
425 | 411 | with torch.cuda.device(device_ids[0]): |
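
With parameters already separated by type in __init__, the reduction thread above no longer needs to re-bucket gradients by dtype and can go straight through coalesce, intra-node nccl.reduce to device 0, inter-node dist.all_reduce, and scatter back. The sketch below isolates just the inter-node half of that path for one already-typed gradient bucket; it assumes the default process group has been initialized, and the helper name allreduce_grad_bucket is illustrative, not part of the patch.

```python
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_grad_bucket(grads):
    # Coalesce same-typed gradients into one flat buffer, average it across
    # all processes, then copy the reduced values back into each gradient.
    coalesced = _flatten_dense_tensors(grads)
    coalesced /= dist.get_world_size()
    dist.all_reduce(coalesced)
    for grad, reduced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        grad.copy_(reduced)
```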