Conversation

@teng-li
Contributor

@teng-li teng-li commented Jan 27, 2018

This PR adds mixed-precision support to distributed training.

Basically, this is done by (a sketch follows the list):

  1. bucketing parameters by data type so that the intra-node broadcast can operate on different data types;
  2. bucketing parameters by data type so that the gradient all-reduce can operate on different data types.
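
For readers, here is a minimal sketch of the dtype-bucketing idea. The helper name and structure are assumptions for illustration (not the code in this PR); it uses the private flatten helpers in torch._utils and the torch.distributed collectives.

```python
# Sketch (assumed helper): bucket gradients by data type so each all-reduce
# runs on tensors of a single dtype (fp16 and fp32 are reduced separately).
from collections import defaultdict

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_grads_by_dtype(module, world_size):
    buckets = defaultdict(list)
    for param in module.parameters():
        if param.requires_grad and param.grad is not None:
            # Key by dtype string, e.g. 'torch.cuda.HalfTensor' vs. 'torch.cuda.FloatTensor'.
            buckets[param.grad.data.type()].append(param.grad.data)

    for grads in buckets.values():
        # One flattened collective per dtype bucket, then scatter results back.
        flat = _flatten_dense_tensors(grads)
        dist.all_reduce(flat)
        flat.div_(world_size)
        for grad, reduced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            grad.copy_(reduced)
```

The same grouping applies to the intra-node broadcast in point 1: flatten each dtype bucket, broadcast it, and unflatten back into the per-device parameter copies.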

Note that the NCCL backend currently only supports a single reduction bucket, so multi-bucket support will be added in the follow-up PR that makes the NCCL backend a separate code path.

Half-precision support is only enabled for the NCCL and Gloo backends.

Also added a constructor option to either sync or not sync the module buffers, since for ResNet we don't need to sync buffers and can hit the same accuracy.
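
As a usage sketch: the keyword name broadcast_buffers matches the attribute in the reviewed code further down, but treat the exact constructor signature here as an assumption.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group(...) was already called by the launcher.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)).cuda()

# broadcast_buffers=False skips the per-iteration sync of module buffers
# (e.g. BatchNorm running stats); for ResNet this reportedly reaches the same accuracy.
ddp_model = DistributedDataParallel(model, broadcast_buffers=False)
```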

Tested by running distributed training on a DGX-1. Bucketing was also verified by printing out the values.

Mixed precision was tested with both NCCL and Gloo as well.

@ngimel
Collaborator

ngimel commented Jan 27, 2018

@teng-li, @csarofeen has a simple distributed module that supports mixed precision (https://github.com/csarofeen/examples/blob/dist_fp16/imagenet/distributed.py), and he was planning to push it to PyTorch core soon. It is just a few lines of code, much easier to follow than DistributedDataParallel. Like yours here, it all-reduces all the parameters at the end of the iteration, and unlike DistributedDataParallel it does not have the limitation that all the parameters have to be updated at each iteration. It does not sync buffers, and you are right, it is a good idea to have an option to sync or not sync buffers. We also tested it with ResNet on a DGX-1.
And thank you for your work on the NCCL backend!
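
For context, the pattern described above amounts to roughly the following (a sketch, not the code at that link):

```python
import torch.distributed as dist


def average_gradients(model, world_size):
    # All-reduce every gradient once, at the end of the iteration.
    # Parameters without a gradient this iteration are simply skipped,
    # so not every parameter has to be updated every step.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data)
            param.grad.data.div_(world_size)


# Typical step (sketch):
#   loss.backward()
#   average_gradients(model, world_size)
#   optimizer.step()
```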

@teng-li
Contributor Author

teng-li commented Jan 27, 2018

@ngimel Thanks for the previous code review. Yes, I did see @csarofeen's DDP code, and it is a lot simpler. But the current DistributedDataParallel module also covers other important use cases that benefit from other backends such as Gloo. And with the current DDP implementation, if we use single-GPU binding and launch one DDP process per GPU, I didn't see much perf degradation compared to Christian's DDP either. Please see

#4870

for the perf comparisons with Christian's simpler version of DDP. The current DDP implementation also covers the single-process multi-GPU case, which is another use case for users beyond the multi-process one. So I think it would be OK to add this mixed-precision support to the current DDP as well, so that other use cases (such as Gloo over Ethernet, which has a perf advantage due to multi-threaded reduction) are also covered. The two launch modes being compared are sketched below.

CC: @apaszke @soumith
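
For reference, the two launch modes look roughly like this. This is a sketch only: how the local rank is obtained depends on the launcher, and the environment variable below is just an example.

```python
import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Assumes the process group is already initialized.
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # example only
model = nn.Linear(1024, 1024)

# One process per GPU ("single GPU binding"): bind this process to one device.
torch.cuda.set_device(local_rank)
ddp_per_gpu = DistributedDataParallel(model.cuda(), device_ids=[local_rank])

# Alternative: a single process driving several GPUs on the node.
# ddp_multi_gpu = DistributedDataParallel(model.cuda(), device_ids=[0, 1, 2, 3])
```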

@teng-li teng-li force-pushed the mixed-prec-dist branch 2 times, most recently from b54884e to a71c1eb on January 30, 2018, 22:33
@teng-li teng-li changed the title Added mixed-precision support in distributed training Added mixed-precision support in distributed training (DON'T MERGE) Jan 30, 2018
@teng-li teng-li changed the title Added mixed-precision support in distributed training (DON'T MERGE) Added mixed-precision support in distributed training Feb 3, 2018
@teng-li
Contributor Author

teng-li commented Feb 3, 2018

Since the single-GPU binding will go to a separate small code path as I planned, I think this PR should be OK to move forward. @apaszke

@teng-li
Contributor Author

teng-li commented Feb 6, 2018

@apaszke Comments addressed.

@teng-li
Contributor Author

teng-li commented Feb 13, 2018

Fixed merge conflict. @apaszke Please also take a look at this PR.

for tensor, param in \
        zip(tensors,
            self.param_type_buckets[tp][dev_idx]):
    param.data.set_(tensor)

dev_param = self.param_type_buckets[tp][dev_idx][param_idx]
self.bucket_map[dev_param] = len(self.bucket_sizes) - 1
bucket_bytes += param.numel() * param.element_size()
self.bucket_sizes[-1] += 1

@teng-li
Contributor Author

teng-li commented Feb 18, 2018

@apaszke refactored

# Bucket parameter type tracking
bucket_param_type = param_tuple[0].type()
param_types.add(bucket_param_type)
# Gloo is not supported due to fp16 performance

bucket_bytes += p.numel() * p.element_size()
self.bucket_sizes[-1] += 1
self.bucket_map[p] = bucket_idx
self.bucket_sizes[bucket_idx] += 1

# module buffer sync
buffers = list(self.module._all_buffers())
if self.broadcast_buffers and len(buffers) > 0:

self.bucket_sizes.append(0)
# Now, we transpose again, so we iterate over bucket_elems, but getting tuples
# of params from each device.
for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):

@apaszke apaszke merged commit 4b8f4fc into pytorch:master Feb 21, 2018