@rohan-varma rohan-varma commented Aug 5, 2020


Closes #38174. Implements a join-based API to support training with the DDP module when different processes have different numbers of inputs. The implementation follows the description in #38174. Details are available in the RFC; as a summary, we make the following changes:

#### Approach

1) Add a context manager that is owned by `DistributedDataParallel` to coordinate the process described below.
2) In the forward pass, we schedule a "present" allreduce in which non-joined processes contribute 1 and joined processes contribute 0. This lets us keep track of joined processes and know when all processes have joined.
3) When a process depletes its input and exits the context manager, it enters "joining" mode and "shadows" the collective communication calls made in the model's forward and backward passes. For example, we schedule the same allreduces in the same order as the backward pass, but with zero tensors (see the sketch after this list).
    a) There are several scenarios in which the backward pass involves more than a single allreduce over all tensors; for example, unused-parameter detection and bucket rebuilding also require collective communication.
4) We provide an option to divide gradients by either the initial world_size or the effective world_size once some ranks have joined (defaulting to the initial world_size). When dividing by the effective world_size, the allreduce division logic divides by the number of non-joined processes rather than the absolute world size.
5) At the end of training, the last process to join is selected as the "authoritative" model copy and broadcasts its parameters.
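
To make steps 2 and 3 concrete, here is a highly simplified sketch of the "present" allreduce and the zero-shadowing idea. This is illustrative only: in the PR the real logic lives inside `DistributedDataParallel` and its `Reducer`, gradients are allreduced in buckets rather than per parameter, and the function below assumes an already-initialized process group.

```python
import torch
import torch.distributed as dist

def step_or_shadow(model, batch, joined, device):
    # Step 2: "present" allreduce -- active ranks contribute 1, joined ranks 0,
    # so every rank learns how many ranks still have inputs.
    present = torch.tensor([0.0 if joined else 1.0], device=device)
    dist.all_reduce(present, op=dist.ReduceOp.SUM)
    num_active = int(present.item())

    if not joined:
        # Normal iteration: DDP allreduces the real gradients during backward.
        model(batch).sum().backward()
    else:
        # Step 3: shadow the gradient allreduces in the same order, contributing
        # zeros, so the collectives stay matched with the non-joined ranks.
        for p in model.parameters():
            dist.all_reduce(torch.zeros_like(p), op=dist.ReduceOp.SUM)
    return num_active
```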

We also make the following smaller changes to support the above:

- Add a `rank` argument to `_distributed_broadcast_coalesced` to specify which rank should perform the broadcast, instead of always using rank 0. This is needed because we cannot arbitrarily select rank 0 in join mode.
- Add a helper function to `DistributedDataParallel` that has all processes agree on a common rank based on some condition (see the sketch after this list). This common rank is then used for broadcasting the final model parameters and module buffers throughout training.
- Expose several helper methods on `Reducer`, such as getters for the `Reducer`'s `Bucket`s and the ability to invoke `rebuildBuckets()` from Python, to support "shadowing" collective calls in join mode.
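
The "agree on a common rank" helper mentioned above can be sketched as follows (assumed semantics, not the exact code in this PR): every rank proposes its own rank if it meets the condition and -1 otherwise, and a MAX allreduce makes all ranks agree on the highest qualifying rank.

```python
import torch
import torch.distributed as dist

def find_common_rank(my_rank: int, meets_condition: bool) -> int:
    # Ranks that do not meet the condition propose -1, so they never win the max.
    candidate = torch.tensor([my_rank if meets_condition else -1], dtype=torch.long)
    dist.all_reduce(candidate, op=dist.ReduceOp.MAX)
    return int(candidate.item())
```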

#### How is it tested?

We have tests covering the following models/scenarios:

- [x] Simple linear model
- [x] Large convolutional model
- [x] Large model with module buffers that are broadcast in the forward pass (ResNet). We verify this with a helper function `will_sync_module_buffers` and ensure it returns true for ResNet (due to batch norm)
- [x] Scenario where a rank calls join() without iterating at all, and therefore without rebuilding buckets (which requires collective communication)
- [x] Model with unused parameters (with `find_unused_parameters=True`)
- [x] Scenarios where different processes iterate for different numbers of iterations
- [x] Test consistency in tie-breaking when multiple ranks are the last ones to join
- [x] Test gradient division by the effective world_size (number of unjoined processes) and by the static world_size
- [x] Test that exceptions during training are correctly propagated by the context manager
- [x] Test expected behavior when the manager is disabled with `enable=False` (for debugging purposes)
- [x] Test expected behavior when more than one process joins early (at different iterations)
- [x] Test model equivalence to local training when used with the join API (a sketch of this pattern follows the list)
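
As referenced in the last item, an uneven-inputs test conceptually looks like the following. This is a hypothetical sketch rather than the actual test code; it assumes the context manager is invoked as a method on the DDP-wrapped model (per item 1 of the Approach) and that the process group is already initialized.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run_uneven(rank: int):
    model = DDP(nn.Linear(10, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    n_iters = 2 + rank  # deliberately uneven across ranks
    with model.join():
        for _ in range(n_iters):
            loss = model(torch.randn(8, 10)).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    # On exit, the last-joined rank broadcasts its parameters, so all ranks
    # should end with identical model state (which the test can then assert).
```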

#### How to run the tests

The main test can be run with `touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" python test/distributed/test_distributed.py -v TestDistBackend.test_ddp_uneven_inputs`.

#### Performance implications

###### Trunk vs PR patched, 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.121 s / 264 QPS vs 0.121 s / 264 QPS
P50 backward-only batch latency & total QPS: 0.087 s / 369 QPS vs 0.087 s / 368 QPS

###### join(enable=True) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.120 s / 265 QPS vs 0.121 s / 264 QPS
P50 backward-only batch latency & total QPS: 0.088 s / 364 QPS vs 0.087 s / 368 QPS

###### join(enable=False) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.121 s / 264 QPS vs 0.121 s / 264 QPS
P50 backward-only batch latency & total QPS: 0.087 s / 368 QPS vs 0.087 s / 368 QPS

###### join(enable=True) with uneven inputs (offset = 2000), 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.183 s / 174 QPS vs 0.121 s / 264 QPS
P50 backward-only batch latency & total QPS: 0.150 s / 213 QPS vs 0.087 s / 368 QPS

###### join(enable=True) with uneven inputs (offset = 2000), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.104 s / 308 QPS vs 0.104 s / 308 QPS
P50 backward-only batch latency & total QPS: 0.070 s / 454 QPS vs 0.070 s / 459 QPS

In the two uneven-input benchmarks above, 4 GPUs immediately deplete their inputs and enter "join" mode (i.e. they do not iterate at all) while the other 28 GPUs in the 32-GPU run iterate as normal. There appears to be a significant performance hit for this case with uneven inputs and multi-node training; strangely, it does not reproduce on a single node (8 GPUs).

###### join(enable=True) with uneven inputs (offset = 10), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.120 s / 265 QPS vs 0.120 s / 265 QPS
P50 backward-only batch latency & total QPS: 0.087 s / 367 QPS vs 0.087 s / 367 QPS

Here the difference in inputs is only 10, i.e. the early-joining ranks iterate just 10 fewer times than the ranks that iterate for the full N, rather than the all-or-nothing split in the tests above.

#### Limitations

1) This is only implemented for MPSD, not SPMD. Per a discussion with @mrshenli, we want to encourage the use of MPSD over SPMD for DDP.
2) This does not currently work with SyncBN or custom collective calls made in the model's forward pass. The join context manager only shadows the broadcast for buffers in the forward pass, the gradient allreduces in the backward pass, unused-parameter reduction, and (optionally) the rebuild-buckets broadcast in the backward pass. Supporting this will require additional design thought.
3) This has not been tested with the [DDP comm. hook](#39272), as that feature is still being finalized. We will add support in follow-up PRs.
4) This has not been thoroughly tested with DDP + RPC. We plan to add support in follow-up PRs.

Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22893859/)!

@dr-ci

dr-ci bot commented Aug 5, 2020

💊 CI failures summary: as of commit adc6bf2, 2/2 failures were possibly introduced in this PR (both non-CircleCI). Extra GitHub checks: 1 failed (codecov.io).



@rohan-varma rohan-varma marked this pull request as ready for review August 5, 2020 06:20
rohan-varma added a commit that referenced this pull request Aug 5, 2020
Pull Request resolved: #42577

rohan-varma added a commit that referenced this pull request Aug 5, 2020
Pull Request resolved: #42577

rohan-varma added a commit that referenced this pull request Aug 25, 2020
Pull Request resolved: #42577

@mrshenli mrshenli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for adding this!

```python
def forward(self, *inputs, **kwargs):
    if self.ddp_join_enabled:
        ones = torch.ones(
            1, device=self.device_ids[0] if self.device_type != "cpu" else "cpu"
```
Contributor: does `device = list(self.module.parameters())[0].device` work here?


Contributor Author: It should work, but the reason I wasn't using it is that I wasn't sure we want the overhead of creating the list only to take the device_id of the first parameter. Maybe we could do this once in the constructor and use that as the device for the process wherever we need it? (A sketch of this idea follows the thread.)


Contributor: Ah, I see. Yep, doing it in the ctor sounds good to me.
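
For reference, the constructor-caching idea discussed in this thread could look roughly like the sketch below. This is illustrative only; the class and attribute names are placeholders, not the code that landed.

```python
import torch
import torch.nn as nn

class DeviceCachingStub(nn.Module):
    """Illustrative stand-in, not the real DistributedDataParallel."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # Cache the device once; next(...) avoids materializing the full
        # parameter list on every forward just to read one device.
        self._module_device = next(module.parameters()).device

    def forward(self, *inputs, **kwargs):
        # The join logic would allocate its "present" tensor on the cached
        # device instead of recomputing it from device_ids on every call.
        present = torch.ones(1, device=self._module_device)  # noqa: F841
        return self.module(*inputs, **kwargs)
```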

```
Args:
    divide_by_initial_world_size (bool): If ``True``, will divide
        gradients by the initial world_size DDP training was launched
```

Contributor: double quote world_size?
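
A toy numeric illustration of the two division modes documented above (the values are made up; assume 6 of 8 ranks are still contributing):

```python
import torch

grad_sum = torch.tensor([12.0])          # summed gradient after the allreduce
initial_world_size, effective_world_size = 8, 6
print(grad_sum / initial_world_size)     # divide_by_initial_world_size=True  -> 1.5
print(grad_sum / effective_world_size)   # divide_by_initial_world_size=False -> 2.0
```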

@zhaojuanmao (Contributor) commented:

@rohan-varma the memory reduction diff is being reverted in #43557, so there may be a merge conflict again (sorry about that). I commented on why the diff was reverted and what the follow-up was.

rohan-varma added a commit that referenced this pull request Aug 25, 2020
…Us required in skipped tests

Closes #41378.
#41973 enhanced the skip decorators to
report the right no. of GPUs required, but this information was not passed to
the main process where the message is actually displayed. This PR uses a
`multiprocessing.Manager()` so that the dictionary modification is reflected
correctly in the main process.

With this diff, we can run a test in #42577 that requires 4 GPUs on a 2 GPU machine, and we get the expected message:

```
test_ddp_uneven_inputs_replicated_error (test_distributed.TestDistBackend) ... skipped 'Need at least 4 CUDA devices'
```
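
The `multiprocessing.Manager()` point above can be illustrated with a small standalone sketch (unrelated to the actual test harness): a plain dict mutated in a child process is invisible to the parent, whereas a Manager-backed dict is shared.

```python
import multiprocessing as mp

def record_skip(shared, rank):
    # Mutations of a Manager dict are visible to the parent process; a plain
    # dict passed across fork/spawn would only change in the child's copy.
    shared[rank] = "Need at least 4 CUDA devices"

if __name__ == "__main__":
    with mp.Manager() as manager:
        shared = manager.dict()
        p = mp.Process(target=record_skip, args=(shared, 0))
        p.start()
        p.join()
        print(dict(shared))  # {0: 'Need at least 4 CUDA devices'}
```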

Differential Revision: [D23285790](https://our.internmc.facebook.com/intern/diff/D23285790/)

[ghstack-poisoned]
@rohan-varma (Contributor Author) replied:

> @rohan-varma the memory reduction diff is being reverted in #43557, so there may be a merge conflict again (sorry about that). I commented on why the diff was reverted and what the follow-up was.

Thanks for notifying! I will hold off until the revert goes through.

facebook-github-bot pushed a commit that referenced this pull request Aug 25, 2020
…reduce (#43543)

Summary:
Pull Request resolved: #43543

Closes #14691. This is not needed in the multiple outputs case, because gloo allreduce
will broadcast the result tensor to all the outputs. See
pytorch/gloo#152 and commit
pytorch/gloo@9cabb5a
for more details. Came across this when debugging #42577.

This effectively reverts #14688 while still keeping the tests.

Tested by ensuring `test_allreduce_basics` in `test_c10d.py` still works as expected.
ghstack-source-id: 110636498

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D23173945

fbshipit-source-id: d1ae08f84b4ac9919c53080949b8fffcb2fe63a8
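
To illustrate the summary above: after an allreduce, every rank already holds the reduced tensor, so no extra broadcast of the output is required. A minimal sketch, assuming an already-initialized (e.g. gloo) process group:

```python
import torch
import torch.distributed as dist

def allreduce_without_broadcast(rank: int) -> torch.Tensor:
    # Each rank contributes its own rank value; after all_reduce the summed
    # result is identical on every rank, with no follow-up broadcast needed.
    t = torch.tensor([float(rank)])
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t
```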
Closes #38174. Implements a join-based API to support training with the DDP module in the scenario where different processes have different no. of inputs. The implementation follows the description in #38174. Details are available in the RFC, but as a summary, we make the following changes:

#### Approach
1) Add a context manager that is owned by `class DistributedDataParallel` to coordinate the below process.
2) In the forward pass, we schedule a "present" allreduce where non-joined process contribute 1 and joined processes contribute 0. This lets us keep track of joined processes and know when all procs are joined.
3) When a process depletes its input and exits the context manager, it enters "joining" mode and attempts to "shadow" the collective comm. calls made in the model's forward and backward pass. For example we schedule the same allreduces in the same order as the backward pass, but with zeros
    a) There are a number of scenarios where in the backward pass, we have more than an allreduce for all tensors. For example, unused param detection and bucket rebuilding requires collective comm. 
4) We provide an option of whether we should divide by the initial world_size or effective world_size when some ranks are gone (default to initial world_size). If dividing by effective world_size, we adjust the allreduce division logic to divide by the effective world size (no. of non-joined procs) rather than the absolute world size.
5) At the end of training, the last joined process is selected to be the "authoritative" model copy and broadcasts its parameters.

We also make the following smaller changes to support the above:
- Add a `rank` argument to `_distributed_broadcast_coalesced` to specify which rank should do the broadcast, instead of forcing rank 0. This is needed because we cannot select rank 0 arbitrarily in the join-mode.
- Add a helper function to `DistributedDataParallel` which will have all processes agree on a common rank based on some condition. This common rank is then used for broadcasting final model params and module buffers throughout training.
- Expose several helper methods on `Reducer` such as getters for the `Reducer`s `Bucket`s and the ability to invoke `rebuildBuckets()` from Python, to support "shadowing" collective calls in join mode.

#### How is it tested?
We have tests covering the following models/scenarios:
- [x] Simple linear model
- [x] Large convolutional model
- [x] Large model with module buffers that are broadcast in the forward pass (resnet). We verify this with a helper function `will_sync_module_buffers` and ensure this is true for ResNet (due to batchnorm)
- [x] Scenario where a rank calls join() without iterating at all, so without rebuilding buckets (which requires collective comm)
- [x] Model with unused params (with `find_unused_parameters=True`)
- [x] Scenarios where different processes iterate for different numbers of iterations.
- [x] Test consistency in tie-breaking when multiple ranks are the last ones to join
- [x] Test gradient division by the effective world_size (no. of unjoined processes) and the static world_size
- [x] Test that exceptions during training are correctly propagated by the context manager
- [x] Test expected behavior when the manager is disabled with `enable=False` (for debug purposes)
- [x] Test expected behavior when > 1 process joins early (at different iterations)
- [x] Test model equivalence to local training when used with join API.

#### How to run the tests
The main test can be run with `touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" python test/distributed/test_distributed.py -v TestDistBackend.test_ddp_uneven_inputs`

#### Performance implications
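For reference, each benchmark below pairs a P50 per-batch latency with total throughput, and the two appear to be related by QPS ≈ batch size / P50 latency. A quick sanity check of my own (not part of the PR's benchmark harness):

```python
# Reported throughput is roughly the per-batch throughput implied by the P50 latency.
batch_size = 32
p50_latency_s = 0.121
print(round(batch_size / p50_latency_s))  # 264 -- matches the reported 264 QPS
```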

###### Trunk vs PR patched, 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.121 s (264 QPS) vs 0.121 s (264 QPS)
P50 backwards-only batch latency & total QPS: 0.087 s (369 QPS) vs 0.087 s (368 QPS)

###### join(enable=True) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.120 s (265 QPS) vs 0.121 s (264 QPS)
P50 backwards-only batch latency & total QPS: 0.088 s (364 QPS) vs 0.087 s (368 QPS)

###### join(enable=False) vs without join, 32 GPUs, batch size = 32, even inputs
P50 forward + backward + optimizer batch latency & total QPS: 0.121 s (264 QPS) vs 0.121 s (264 QPS)
P50 backwards-only batch latency & total QPS: 0.087 s (368 QPS) vs 0.087 s (368 QPS)

###### join(enable=True) with uneven inputs (offset = 2000), 32 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.183 s (174 QPS) vs 0.121 s (264 QPS)
P50 backwards-only batch latency & total QPS: 0.150 s (213 QPS) vs 0.087 s (368 QPS)

###### join(enable=True) with uneven inputs (offset = 2000), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.104 s (308 QPS) vs 0.104 s (308 QPS)
P50 backwards-only batch latency & total QPS: 0.070 s (454 QPS) vs 0.070 s (459 QPS)

The two uneven-input benchmarks above were conducted with 4 GPUs immediately depleting their inputs and entering "join" mode (i.e. not iterating at all), while the remaining GPUs (28 in the 32-GPU case) iterated as normal. There appears to be a significant perf hit for this case with uneven inputs and multi-node training. Strangely, this does not reproduce with a single node (8 GPUs).

###### join(enable=True) with uneven inputs (offset = 10), 8 GPUs, batch size = 32
P50 forward + backward + optimizer batch latency & total QPS: 0.120 s (265 QPS) vs 0.120 s (265 QPS)
P50 backwards-only batch latency & total QPS: 0.087 s (367 QPS) vs 0.087 s (367 QPS)
Here the input counts differ by only 10, i.e. the early-joining ranks iterate 10 fewer times than the ranks that iterate for the full N, instead of the all-or-nothing split in the tests above.


#### Limitations
1) This is only implemented for MPSD, not SPMD. Per a discussion with @mrshenli we want to encourage the use of MPSD over SPMD for DDP.
2) This does not currently work with SyncBN or custom collective calls made in the model's forward pass (see the hypothetical example after this list). This is because the `join` context manager only shadows the `broadcast` for buffers in the forward pass, the gradient allreduces in the backward pass, the unused-parameter reduction, and (optionally) the bucket-rebuilding broadcast in the backward pass. Supporting this will require additional design thought.
3) Has not been tested with the [DDP comm. hook](#39272) as this feature is still being finalized/in progress. We will add support for this in follow up PRs.
4) Has not been thoroughly tested with DDP + RPC. We plan to add support for this in follow up PRs. 
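For illustration, a hypothetical model like the following (my own example, not code from this PR) issues its own collective in the forward pass; join mode does not shadow that call, so once any rank has joined, the remaining ranks would block on it:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

class ForwardAllReduce(nn.Module):
    """Toy module whose forward pass performs a custom collective."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.fc(x)
        # Custom collective in the forward pass: not shadowed by join mode,
        # so it can hang once some ranks have exhausted their inputs.
        dist.all_reduce(out)
        return out
```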

Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22893859/)!

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Aug 26, 2020
Pull Request resolved: #42577

ghstack-source-id: 110710165

Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)
facebook-github-bot pushed a commit that referenced this pull request Aug 26, 2020
… in (#43468)

Summary:
Pull Request resolved: #43468

Closes #41378.
#41973 enhanced the skip decorators to
report the right no. of GPUs required, but this information was not passed to
the main process where the message is actually displayed. This PR uses a
`multiprocessing.Manager()` so that the dictionary modification is reflected
correctly in the main process.
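The mechanism is the standard `multiprocessing.Manager()` proxy dict; a minimal illustration of why the shared dict is needed (not the actual test harness):

```python
import multiprocessing as mp

def record_skip_reason(shared, plain):
    shared["skip_reason"] = "Need at least 4 CUDA devices"
    plain["skip_reason"] = "Need at least 4 CUDA devices"  # lost: plain dict is copied into the child

if __name__ == "__main__":
    manager = mp.Manager()
    shared, plain = manager.dict(), {}
    p = mp.Process(target=record_skip_reason, args=(shared, plain))
    p.start(); p.join()
    print(dict(shared))  # {'skip_reason': 'Need at least 4 CUDA devices'}
    print(plain)         # {} -- the child's modification is not reflected in the parent
```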
ghstack-source-id: 110684228

Test Plan:
With this diff, we can run a test such as the one in #42577 that requires 4 GPUs on a 2-GPU machine, and we get the expected message:

```
test_ddp_uneven_inputs_replicated_error (test_distributed.TestDistBackend) ... skipped 'Need at least 4 CUDA devices'
```

Reviewed By: mrshenli

Differential Revision: D23285790

fbshipit-source-id: ac32456ef3d0b1d8f1337a24dba9f342c736ca18
rohan-varma added a commit that referenced this pull request Aug 26, 2020
Pull Request resolved: #42577

ghstack-source-id: 110773782

Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)
rohan-varma added a commit that referenced this pull request Aug 27, 2020
Pull Request resolved: #42577

ghstack-source-id: 110814064

Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)
rohan-varma added a commit that referenced this pull request Aug 31, 2020
Pull Request resolved: #42577

ghstack-source-id: 111033819

Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)
@codecov

codecov bot commented Aug 31, 2020

Codecov Report

Merging #42577 into gh/rohan-varma/152/base will decrease coverage by 0.07%.
The diff coverage is 26.25%.


@@                     Coverage Diff                     @@
##           gh/rohan-varma/152/base   #42577      +/-   ##
===========================================================
- Coverage                    69.32%   69.24%   -0.08%     
===========================================================
  Files                          378      378              
  Lines                        46749    46824      +75     
===========================================================
+ Hits                         32408    32424      +16     
- Misses                       14341    14400      +59     
| Impacted Files | Coverage Δ |
|---|---|
| torch/nn/parallel/distributed.py | 42.53% <25.31%> (-8.24%) ⬇️ |
| torch/testing/_internal/common_distributed.py | 65.86% <100.00%> (ø) |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 3aeb70d...adc6bf2

@facebook-github-bot
Contributor

This pull request has been merged in 4e4626a.

@facebook-github-bot facebook-github-bot deleted the gh/rohan-varma/152/head branch September 4, 2020 14:17
rohan-varma added a commit that referenced this pull request Sep 23, 2020
… in DDP training"


This request came up in feature review for DDP uneven inputs, so this PR adds a warning when the discrepancy in the number of inputs across processes is much higher than expected when running with uneven inputs. A skew in the thousands can reduce performance by a nontrivial amount, as shown in the benchmarks in #42577, so it was proposed to add this warning. Tested by running the tests so that the threshold is hit and observing the output.
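A hypothetical sketch of the kind of check described here (the helper name and threshold are illustrative, not the PR's actual implementation):

```python
import warnings

UNEVEN_INPUT_WARN_THRESHOLD = 1000  # illustrative value

def maybe_warn_uneven_inputs(iterations_per_rank):
    # Warn when the gap between the rank that processed the most batches and the
    # rank that processed the fewest exceeds the threshold.
    skew = max(iterations_per_rank) - min(iterations_per_rank)
    if skew > UNEVEN_INPUT_WARN_THRESHOLD:
        warnings.warn(
            f"Detected an input skew of {skew} iterations across ranks; "
            "large skews can significantly reduce DDP join performance."
        )
```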

Differential Revision: [D23719270](https://our.internmc.facebook.com/intern/diff/D23719270/)

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Sep 24, 2020
… in DDP training"


Differential Revision: [D23719270](https://our.internmc.facebook.com/intern/diff/D23719270/)