Added mixed-precision support in distributed training #4891
Conversation
Force-pushed from db6011d to fc2c0d3
@teng-li, @csarofeen has a simple distributed module that supports mixed precision (https://github.com/csarofeen/examples/blob/dist_fp16/imagenet/distributed.py), and he was planning to push it to pytorch core soon. It is just a few lines of code and much easier to follow than DistributedDataParallel. Like yours here, it all-reduces all the parameters at the end of the iteration, and unlike DistributedDataParallel it does not have the limitation that all the parameters have to be updated at each iteration. It does not sync buffers, and you are right that it is a good idea to have an option to sync or not sync buffers. We also tested it with ResNet on DGX-1.
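For readers following along, a minimal sketch of the reduce-everything-once-per-iteration pattern described above, assuming a torch.distributed process group is already initialized; the helper name and explicit averaging step are illustrative, not code from the linked file:

    import torch.distributed as dist

    def allreduce_gradients(model, world_size):
        # After backward(), sum each parameter's gradient across workers and
        # average it; parameters without a gradient this iteration are skipped.
        for param in model.parameters():
            if param.grad is None:
                continue
            dist.all_reduce(param.grad.data)
            param.grad.data.div_(world_size)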
@ngimel Thanks for the previous code review. Yes, I did see @csarofeen's DDP code, and it is a lot simpler. But our current DistributedDataParallel also covers other important use cases and can benefit from other backends such as gloo. With the current DDP implementation, if we bind a single GPU per process and launch one DDP process per GPU, I didn't see much perf degradation compared to Christian's DDP either; please see the perf comparisons with Christian's simpler version of DDP. The current DDP implementation also covers the single-process multi-GPU case, which is another use case beyond the multi-process one. So I think it should be OK to add this mixed-precision support to the current DDP so that other use cases (such as gloo over Ethernet, which has a perf advantage due to multi-threaded reduction) are covered as well.
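The two launch modes compared above look roughly like this (a hypothetical wrap_model helper, not code from this PR; it assumes torch.distributed is already initialized and CUDA devices are available):

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def wrap_model(model, local_rank=None):
        # One process per GPU ("single GPU binding"): each rank keeps the module
        # on its own device and gradients are reduced across processes.
        if local_rank is not None:
            model = model.cuda(local_rank)
            return DistributedDataParallel(model, device_ids=[local_rank])
        # Single process driving all visible GPUs on the node: DDP replicates
        # the module across the listed devices itself.
        model = model.cuda(0)
        return DistributedDataParallel(
            model, device_ids=list(range(torch.cuda.device_count())))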
Force-pushed from b54884e to a71c1eb
Since the single GPU binding will go to a different small code path as I planned, I think this PR should be OK to proceed. @apaszke
Force-pushed from 88b8cdd to 52dcc1d
@apaszke Comments addressed.
Force-pushed from 52dcc1d to e83e35c
Fixed merge conflict. @apaszke Please also take a look at this PR.
torch/nn/parallel/distributed.py (outdated)

    for tensor, param in \
            zip(tensors,
                self.param_type_buckets[tp][dev_idx]):
        param.data.set_(tensor)
torch/nn/parallel/distributed.py (outdated)

    dev_param = self.param_type_buckets[tp][dev_idx][param_idx]
    self.bucket_map[dev_param] = len(self.bucket_sizes) - 1
    bucket_bytes += param.numel() * param.element_size()
    self.bucket_sizes[-1] += 1
Force-pushed from e92a523 to f65d51c
@apaszke Refactored.
torch/nn/parallel/distributed.py (outdated)

    # Bucket parameter type tracking
    bucket_param_type = param_tuple[0].type()
    param_types.add(bucket_param_type)
    # Gloo is not supported due to fp16 performance
torch/nn/parallel/distributed.py

    bucket_bytes += p.numel() * p.element_size()
    self.bucket_sizes[-1] += 1
    self.bucket_map[p] = bucket_idx
    self.bucket_sizes[bucket_idx] += 1
torch/nn/parallel/distributed.py (outdated)

    # module buffer sync
    buffers = list(self.module._all_buffers())
    if self.broadcast_buffers and len(buffers) > 0:    # was: if len(buffers) > 0:
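A minimal sketch of what the buffer sync guarded by the new broadcast_buffers flag boils down to (a simplified, hypothetical helper that broadcasts one buffer at a time; an actual implementation would likely coalesce the buffers into fewer calls):

    import torch.distributed as dist

    def sync_buffers(module, src_rank=0):
        # Broadcast every registered buffer (e.g. BatchNorm running_mean/var)
        # from src_rank so all replicas see identical buffer values.
        for buf in module._all_buffers():
            dist.broadcast(buf, src_rank)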
torch/nn/parallel/distributed.py

    self.bucket_sizes.append(0)
    # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
    # of params from each device.
    for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):
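The hunks above sketch how parameters are grouped into buckets. As a rough illustration of what then happens to one bucket, here is a sketch using the flatten helpers from torch._utils; allreduce_bucket is a hypothetical name, not the PR's code:

    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def allreduce_bucket(grads, world_size):
        # grads: gradient tensors of one bucket, all with the same dtype/device.
        flat = _flatten_dense_tensors(grads)   # one contiguous buffer
        dist.all_reduce(flat)                  # single reduction for the bucket
        flat.div_(world_size)                  # average across workers
        for g, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            g.copy_(synced)                    # write results back in place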
This PR adds mixed-precision support to distributed training.
Basically, this is done by bucketing parameters by their tensor type, so that half- and single-precision tensors are reduced in separate buckets.
Note that the NCCL backend currently only supports a single reduction bucket, so multi-bucket support will be added in the other PR that makes the NCCL backend a separate code path.
Half-precision support is only enabled for the NCCL and Gloo backends.
Also added a constructor option to either sync or not sync the module buffers, since for ResNet we don't need to sync buffers and can still hit the same accuracy.
Tested by running distributed training on DGX-1. Bucketing was also tested by printing out the values.
Tested mixed precision with both NCCL and Gloo as well.
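For reference, a minimal usage sketch of the behavior described above, assuming the process group is initialized with the NCCL backend and that the constructor flag keeps the broadcast_buffers name seen in the diff hunks; the toy model is illustrative only:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))
    model = model.cuda().half()   # half-precision parameters
    # broadcast_buffers=False skips the per-iteration buffer sync, which, per the
    # ResNet runs mentioned above, does not change final accuracy.
    model = DistributedDataParallel(model, broadcast_buffers=False)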