|
2 | 2 | import math |
3 | 3 | import threading |
4 | 4 | import copy |
5 | | -from collections import defaultdict |
6 | 5 |
|
7 | 6 | import torch |
8 | 7 | from torch.autograd import Variable |
@@ -136,61 +135,47 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0, |
136 | 135 | else: |
137 | 136 | self._module_copies = [self.module] |
138 | 137 |
|
139 | | - # Split parameters into type buckets so that parameter sync (broadcast) |
140 | | - # can operate on mixed parameter types (e.g. mixed half and float) |
141 | | - self.param_type_buckets = \ |
142 | | - defaultdict(lambda: [[] for _ in range(len(self.device_ids))]) |
143 | | - |
144 | | - for device_idx, module in enumerate(self._module_copies): |
145 | | - for p in module.parameters(): |
146 | | - tp = p.type() |
147 | | - if tp == torch.cuda.HalfTensor and \ |
148 | | - dist._backend != dist.dist_backend.NCCL and \ |
149 | | - dist._backend != dist.dist_backend.GLOO: |
150 | | - raise RuntimeError("DistributedDataParallel currently only " |
151 | | - "supports half precision parameters " |
152 | | - "with NCCL backend") |
153 | | - # Add the parameter into the type bucket |
154 | | - self.param_type_buckets[tp][device_idx].append(p) |
155 | | - |
156 | | - # TODO, adding mixed precision support in NCCL reduction code path |
157 | | - # This is because NCCL backend doesn't support multiple reduction |
158 | | - # bucket |
159 | | - if len(self.param_type_buckets) > 1 and \ |
160 | | - dist._backend == dist.dist_backend.NCCL: |
161 | | - raise RuntimeError("DistributedDataParallel currently doesn't " |
162 | | - "support mixed precision type for NCCL backend") |
163 | | - |
164 | | - # Split parameters into buckets that will coalesce reductions |
165 | | - # |
166 | | - # Note that the parameters have already been split by type above. |
167 | | - # Here, for each type, we further split each type of parameters into |
168 | | - # reduction buckets so that each bucket will only have a single type |
169 | | - # of parameters. Therefore subsequent all-reduce can be successful since |
170 | | - # the reduction operation needs to operate on the same kind of data type |
171 | | - self.bucket_sizes = [] |
172 | | - self.bucket_map = {} |
173 | | - |
174 | 138 | # Currently NCCL backend only supports single reduction thread/bucket |
175 | 139 | if dist._backend == dist.dist_backend.NCCL: |
176 | 140 | bucket_bytes_cap = float('inf') |
177 | 141 | else: |
178 | 142 | bucket_bytes_cap = 1 * MB |
179 | 143 |
|
180 | | - for tp in self.param_type_buckets: |
181 | | - # to init the first bucket immediately for each type |
182 | | - bucket_bytes = bucket_bytes_cap |
183 | | - for param_idx, param in enumerate(self.param_type_buckets[tp][0]): |
184 | | - if not param.requires_grad: |
| 144 | + param_buckets = [] |
| 145 | + # Split each replica's parameters into same-type, size-capped buckets |
| 146 | + for dev_idx, module in enumerate(self._module_copies): |
| 147 | + param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap))) |
| 148 | + |
| 149 | + self.bucket_sizes = [] |
| 150 | + self.bucket_map = {} |
| 151 | + param_types = set() |
| 152 | + |
| 153 | + for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)): |
| 154 | + self.bucket_sizes.append(0) |
| 155 | + for idx, param_tuple in enumerate(zip(*param_buckets_tuple)): |
| 156 | + if idx == 0: |
| 157 | + # Bucket parameter type tracking |
| 158 | + bucket_param_type = param_tuple[0].type() |
| 159 | + param_types.add(bucket_param_type) |
| 160 | + # Only the NCCL and Gloo backends support half-precision parameters |
| 161 | + if bucket_param_type == torch.cuda.HalfTensor and \ |
| 162 | + dist._backend != dist.dist_backend.NCCL and \ |
| 163 | + dist._backend != dist.dist_backend.GLOO: |
| 164 | + raise RuntimeError("DistributedDataParallel currently only " |
| 165 | + "supports half precision parameters " |
| 166 | + "with Nccl and Gloo backend") |
| 167 | + if not param_tuple[0].requires_grad: |
185 | 168 | continue |
186 | | - if bucket_bytes >= bucket_bytes_cap: |
187 | | - self.bucket_sizes.append(0) |
188 | | - bucket_bytes = 0 |
189 | | - for dev_idx in range(len(self.device_ids)): |
190 | | - dev_param = self.param_type_buckets[tp][dev_idx][param_idx] |
191 | | - self.bucket_map[dev_param] = len(self.bucket_sizes) - 1 |
192 | | - bucket_bytes += param.numel() * param.element_size() |
193 | | - self.bucket_sizes[-1] += 1 |
| 169 | + for p in param_tuple: |
| 170 | + self.bucket_map[p] = bucket_idx |
| 171 | + self.bucket_sizes[bucket_idx] += 1 |
| 172 | + |
| 173 | + # TODO: add mixed precision support to the NCCL reduction code path. |
| 174 | + # Mixed parameter types are rejected here because the NCCL backend |
| 175 | + # only supports a single reduction bucket. |
| 176 | + if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL: |
| 177 | + raise RuntimeError("DistributedDataParallel currently doesn't " |
| 178 | + "support mixed precision type for NCCL backend") |
194 | 179 |
|
195 | 180 | self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))] |
196 | 181 | self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))] |
|
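For context, the sketch below illustrates the bucketing contract the new code relies on: parameters are grouped by tensor type and flushed into a bucket once a byte-size cap is reached, so every bucket that later feeds an all-reduce holds a single type. This is a minimal, modern-PyTorch approximation, not the actual `torch._utils._take_tensors` (it ignores details such as sparse tensors); `take_tensors_sketch` and the trailing `bucket_map`/`bucket_sizes` mirror are illustrative names, not part of the patch.

```python
from collections import OrderedDict

import torch


def take_tensors_sketch(tensors, size_limit_bytes):
    """Yield lists of same-type tensors, each list roughly capped in total bytes."""
    buf = OrderedDict()  # (dtype, device) -> [tensors, accumulated bytes]
    for tensor in tensors:
        key = (tensor.dtype, tensor.device)
        if key not in buf:
            buf[key] = [[], 0]
        buf[key][0].append(tensor)
        buf[key][1] += tensor.numel() * tensor.element_size()
        if buf[key][1] >= size_limit_bytes:
            # Flush a full bucket; it contains a single tensor type only.
            yield buf[key][0]
            buf[key] = [[], 0]
    for bucket, _ in buf.values():
        if bucket:
            # Flush the remaining, partially filled buckets.
            yield bucket


if __name__ == "__main__":
    MB = 1024 * 1024
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    # With the 1 MB cap used for non-NCCL backends above, this tiny model
    # collapses into a single bucket.
    buckets = list(take_tensors_sketch(model.parameters(), 1 * MB))

    # Mirror of the constructor's bookkeeping for a single replica:
    # each parameter maps to the index of the bucket it is reduced in.
    bucket_map = {p: i for i, bucket in enumerate(buckets) for p in bucket}
    bucket_sizes = [len(bucket) for bucket in buckets]
    print(bucket_sizes, [tuple(p.shape) for p in buckets[0]])
```

Setting `bucket_bytes_cap` to infinity for NCCL, as the patch does, simply forces all parameters of a given type into one bucket, which matches the single-reduction-bucket restriction noted in the TODO above.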