add mpi support for DDP #5919
Conversation
CC @teng-li @pytorchbot test this please. NB: I don't think CI is covering MPI at the moment; that should be fixed.
Thanks for the PR, but I think we should try to look for alternative ways to add support for CPU training. Right now this adds a third possible code path, for a third backend, and that's just going to be unmaintainable in the long run. It's not using most of the code in this file, which suggests that it might be better to just make it a class in a separate file.
@apaszke Thanks for the feedback. I'm looking to make it a class in a separate file, but I'm wondering what the class name should be. Can I use a new class name like torch.nn.parallel.MPIDistributedDataParallel()? The key point is that the user has to call a different API for distributed CPU and GPU training.
Any proposal for this new class name to support distributed CPU training?
@xhzhao torch.nn.parallel.DistributedDataParallelMPI() seems better
Sorry for a late reply. Actually it shouldn't have MPI in the name. It will work with any backend that supports CPU, so it's more like
@apaszke then the existing DDP should probably be renamed to DistributedDataParallelGPU?
Idk, we can do that if you feel strongly about it. I don't mind leaving it as is.
@apaszke I will do it later when the DDP CPU gets merged
@xhzhao Besides MPI, the TCP and Gloo backends also support CPU collective ops; you could later mention the supported backends in the module comments.
@teng-li will do

@pytorchbot test this please
apaszke left a comment:
This needs a test.
Also, is there any difference in allreduce_params from the regular DistributedDataParallel and what you have here? It would be better to avoid code duplication and refactor it into a shared function.
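For illustration, the type-bucketing loop that both allreduce_params implementations share could plausibly be pulled into one helper along these lines. This is a torch-free sketch: FakeParam is a hypothetical stand-in for a torch Parameter, and the real helper would follow the bucketing with a flattened all_reduce per bucket.

```python
class FakeParam:
    """Hypothetical stand-in for a torch Parameter, for illustration only."""
    def __init__(self, data, requires_grad=True, grad=0):
        self.data = data
        self.requires_grad = requires_grad
        self.grad = grad

def bucket_by_type(params):
    """Group grad-carrying parameters by the type of their data tensor,
    mirroring the loop duplicated in both allreduce_params versions."""
    buckets = {}
    for param in params:
        if param.requires_grad and param.grad is not None:
            tp = type(param.data)
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(param)
    return buckets

# Floats and ints land in separate buckets; params that need no grad are skipped.
buckets = bucket_by_type([FakeParam(1.0), FakeParam(2),
                          FakeParam(3.0, requires_grad=False)])
```

Both DDP variants could then call a helper like this and differ only in how they launch the actual reduction.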
torch/nn/parallel/distributed_cpu.py (Outdated):

    def forward(self, *inputs, **kwargs):
        if self.first_call:
            self.weight_broadcast()
            self.first_call = False
torch/nn/parallel/distributed_cpu.py (Outdated):

    self.first_call = True

    def allreduce_params():
        if (self.needs_reduction):
torch/nn/parallel/distributed_cpu.py (Outdated):

        buckets[tp] = []
    buckets[tp].append(param)

    for tp in buckets:
torch/nn/parallel/distributed_cpu.py (Outdated):

    if param.requires_grad and param.grad is not None:
        tp = type(param.data)
        if tp not in buckets:
            buckets[tp] = []
torch/nn/parallel/distributed_cpu.py (Outdated):

    .. warning::
        This module works only with the ``mpi`` backend.
        Other backends like ``gloo`` and ``tcp`` are not tested yet.
@apaszke thanks for your feedback, I updated the code again:
Any update on the code review?
I will get back to you this week, sorry for the delay
Well, I'm not into the name DDP-CPU. MPI != CPU. As far as I know, there are at least 3 CUDA-aware MPI implementations available. And I managed to compile PyTorch with
apaszke left a comment:
Two things and it should be good to merge.
    # Shuffle the input so that DDP input is different
    input_cpu = input_cpu[torch.randperm(global_bs)]

    self._barrier()
torch/nn/parallel/distributed_cpu.py (Outdated):

    if param.requires_grad:
        param.register_hook(allreduce_hook)

    def weight_broadcast(self):
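The first_call/hook flow quoted in the hunks above can be sketched, torch-free, like this. Here `broadcast` and `allreduce` are stand-ins for torch.distributed.broadcast/all_reduce, and `after_backward` is a hypothetical stand-in for the per-parameter autograd hooks that fire during the real backward pass.

```python
class DistributedCPUSketch:
    """Illustration only: broadcast weights from rank 0 on the first
    forward(), then all-reduce gradients once per backward pass."""
    def __init__(self, params, broadcast, allreduce):
        self.params = params
        self.broadcast = broadcast
        self.allreduce = allreduce
        self.first_call = True
        self.needs_reduction = False

    def forward(self, x):
        if self.first_call:          # one-time weight sync across workers
            for p in self.params:
                self.broadcast(p)
            self.first_call = False
        self.needs_reduction = True  # gradients must be reduced after backward
        return x

    def after_backward(self):
        # In the real module this is driven by register_hook callbacks.
        if self.needs_reduction:
            self.needs_reduction = False
            for p in self.params:
                self.allreduce(p)

calls = []
m = DistributedCPUSketch(['w'],
                         lambda p: calls.append(('bcast', p)),
                         lambda p: calls.append(('allreduce', p)))
m.forward(None); m.after_backward()
m.forward(None); m.after_backward()
# Broadcast happens only once; allreduce happens once per backward pass.
```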
@Stonesjtu If you would like to use GPU training, why not just use the existing DistributedDataParallel (DDP) module? I believe CUDA-aware MPI should work with it. My understanding is that CPU DDP is currently missing; instead, let's just get a CPU version of DDP working with all supported CPU backends, which makes more sense IMHO. @xhzhao I am wondering if you could also test your implementation with the Gloo backend as well; that's gonna be super useful.
@Stonesjtu this PR title is a bit of a mismatch with the real target, as our discussion went on.
        raise unittest.SkipTest("worldsize is too small to run group tests")

    elif BACKEND == 'mpi':
        WORLD_SIZE = os.environ['WORLD_SIZE']
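The test hunk above hints at the backend-specific setup: with MPI, the world size is fixed by the launcher (e.g. `mpirun -n 3`), so the test reads WORLD_SIZE from the environment instead of hard-coding it. A hedged sketch of that selection logic; the default of 3 for other backends is an assumption for illustration only.

```python
import os

def world_size_for(backend, env=None):
    """Pick the test world size for a given torch.distributed backend.

    For 'mpi' the launcher decides, so read it from the environment;
    other backends fall back to a fixed default (assumed here to be 3).
    """
    env = os.environ if env is None else env
    if backend == 'mpi':
        return int(env['WORLD_SIZE'])
    return 3
```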
Thanks a lot @xhzhao!
overview

This PR's target is to add MPI support for torch.nn.parallel.DistributedDataParallel().
AFAIK, PyTorch DDP only supports the nccl and gloo backends, and I think it would be great to support the mpi backend when the user has CPUs only, especially for researchers with supercomputer access.

reference issue

code change:

I only added some lines in the init() and forward() functions, without any change to the CUDA code (except for the indentation).

validation:

This code passed my test case for the mnist example: https://github.com/xhzhao/examples/tree/master/mnist
I'm also looking to add one more case in pytorch/test/test_distributed.py, but I could not find a method to launch an MPI task from Python. I hope this problem will be fixed in the future.