Commit f65d51c

Use _take_tensors to simplify the logic

1 parent e4d40d6 · commit f65d51c

torch/nn/parallel/distributed.py

Lines changed: 34 additions & 49 deletions
@@ -2,7 +2,6 @@
 import math
 import threading
 import copy
-from collections import defaultdict
 
 import torch
 from torch.autograd import Variable
@@ -136,61 +135,47 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0,
         else:
             self._module_copies = [self.module]
 
-        # Split parameters into type buckets so that parameter sync (broadcast)
-        # can operates on mixed parameter types. (e.g. mixed half and float)
-        self.param_type_buckets = \
-            defaultdict(lambda: [[] for _ in range(len(self.device_ids))])
-
-        for device_idx, module in enumerate(self._module_copies):
-            for p in module.parameters():
-                tp = p.type()
-                if tp == torch.cuda.HalfTensor and \
-                        dist._backend != dist.dist_backend.NCCL and \
-                        dist._backend != dist.dist_backend.GLOO:
-                    raise RuntimeError("DistributedDataParallel currently only "
-                                       "supports half precision parameters "
-                                       "with NCCL backend")
-                # Add the parameter into the type bucket
-                self.param_type_buckets[tp][device_idx].append(p)
-
-        # TODO, adding mixed precision support in NCCL reduction code path
-        # This is because NCCL backend doesn't support multiple reduction
-        # bucket
-        if len(self.param_type_buckets) > 1 and \
-                dist._backend == dist.dist_backend.NCCL:
-            raise RuntimeError("DistributedDataParallel currently doesn't "
-                               "support mixed precision type for NCCL backend")
-
-        # Split parameters into buckets that will coalesce reductions
-        #
-        # Note that previously we have already splitted parameters by the type.
-        # Here, for each type, we further split each type of parameters into
-        # reduction buckets so that each bucket will only have a single type
-        # of parameters. Therefore subsequent all-reduce can be successful since
-        # the reduction operation needs to operate on the same kind of data type
-        self.bucket_sizes = []
-        self.bucket_map = {}
-
         # Currently NCCL backend only supports single reduction thread/bucket
         if dist._backend == dist.dist_backend.NCCL:
             bucket_bytes_cap = float('inf')
         else:
             bucket_bytes_cap = 1 * MB
 
-        for tp in self.param_type_buckets:
-            # to init the first bucket immediately for each type
-            bucket_bytes = bucket_bytes_cap
-            for param_idx, param in enumerate(self.param_type_buckets[tp][0]):
-                if not param.requires_grad:
+        param_buckets = []
+        # Split the parameters into buckets and by types as well
+        for dev_idx, module in enumerate(self._module_copies):
+            param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap)))
+
+        self.bucket_sizes = []
+        self.bucket_map = {}
+        param_types = set()
+
+        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
+            self.bucket_sizes.append(0)
+            for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):
+                if idx == 0:
+                    # Bucket parameter type tracking
+                    bucket_param_type = param_tuple[0].type()
+                    param_types.add(bucket_param_type)
+                    # Gloo is not supported due to fp16 performance
+                    if bucket_param_type == torch.cuda.HalfTensor and \
+                            dist._backend != dist.dist_backend.NCCL and \
+                            dist._backend != dist.dist_backend.GLOO:
+                        raise RuntimeError("DistributedDataParallel currently only "
+                                           "supports half precision parameters "
+                                           "with Nccl and Gloo backend")
+                if not param_tuple[0].requires_grad:
                     continue
-                if bucket_bytes >= bucket_bytes_cap:
-                    self.bucket_sizes.append(0)
-                    bucket_bytes = 0
-                for dev_idx in range(len(self.device_ids)):
-                    dev_param = self.param_type_buckets[tp][dev_idx][param_idx]
-                    self.bucket_map[dev_param] = len(self.bucket_sizes) - 1
-                bucket_bytes += param.numel() * param.element_size()
-                self.bucket_sizes[-1] += 1
+                for p in param_tuple:
+                    self.bucket_map[p] = bucket_idx
+                self.bucket_sizes[bucket_idx] += 1
+
+        # TODO, adding mixed precision support in NCCL reduction code path
+        # This is because NCCL backend doesn't support multiple reduction
+        # bucket.
+        if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL:
+            raise RuntimeError("DistributedDataParallel currently doesn't "
+                               "support mixed precision type for NCCL backend")
 
         self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
         self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
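
Note (not part of the commit): below is a minimal sketch of the bucketing scheme this change relies on, assuming _take_tensors(tensors, size_limit) is importable from torch._utils (as in PyTorch versions around this commit) and yields lists of same-typed tensors whose combined byte size stays under the cap. Because every module replica holds parameters with identical shapes, dtypes, and ordering, zipping the per-replica bucket lists lines up corresponding parameters bucket by bucket, which is what the new loop over zip(*param_buckets) exploits. The tensor shapes and the two toy "replicas" are made up for illustration.

import torch
from torch._utils import _take_tensors

MB = 1024 * 1024
bucket_bytes_cap = 1 * MB

# Stand-in for the per-device module replicas: two replicas with the same
# parameter shapes and types (hypothetical CPU tensors for simplicity).
replicas = [
    [torch.randn(256, 256), torch.randn(512, 512), torch.randn(16)]
    for _ in range(2)
]

# One list of buckets per replica, mirroring param_buckets in the diff.
param_buckets = [list(_take_tensors(params, bucket_bytes_cap))
                 for params in replicas]

bucket_map = {}    # parameter -> bucket index (keyed by id() here for simplicity)
bucket_sizes = []  # number of parameters recorded in each bucket

for bucket_idx, bucket_tuple in enumerate(zip(*param_buckets)):
    # bucket_tuple holds the bucket_idx-th bucket from every replica.
    bucket_sizes.append(0)
    for param_tuple in zip(*bucket_tuple):
        # param_tuple holds the "same" parameter from every replica.
        for p in param_tuple:
            bucket_map[id(p)] = bucket_idx
        bucket_sizes[bucket_idx] += 1

# Shows how many buckets _take_tensors produced and how many parameters
# landed in each one for these toy sizes.
print(len(bucket_sizes), bucket_sizes)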

0 commit comments