
Conversation

@sinannasir
Contributor

@sinannasir sinannasir commented Jun 22, 2020

Stack from ghstack:

Summary:

  1. In reducer.cpp, we add a new boolean find_unused_param_ whose value is set in Reducer::prepare_for_backward.
    If !find_unused_param_, the reducer skips allreduce(local_used_maps_dev_). (A minimal usage sketch follows this list.)
  2. Solves issue 38942.
  3. Fixes incorrect inference of find_unused_parameters_ from checks such as outputs.empty() or unused_parameters_.empty().
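
A minimal sketch of the Python-side setting this change targets (the model and process-group setup are illustrative, not from this PR): with find_unused_parameters=False, every parameter is guaranteed to receive a gradient, so the reducer can skip the extra allreduce of the used-parameter bitmap.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    # Single-machine rendezvous; address and port are illustrative.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = torch.nn.Linear(10, 10)  # every parameter is used in forward
    ddp_model = DDP(model, find_unused_parameters=False)
    loss = ddp_model(torch.randn(20, 10)).sum()
    loss.backward()  # with this PR, no allreduce of local_used_maps_dev_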

Test Plan:

  1. Run test/distributed/test_c10d.py and make sure all tests pass.
  2. A new test case test_find_unused_parameters_when_unused_parameters_empty is included. The old reducer.cpp failed that unit test because it inferred find_unused_parameters_ from unused_parameters_.empty(); the current reducer.cpp passes it.
  3. Two test cases, test_forward_backward_unused_parameters and test_forward_backward_optimizer, were failing because find_unused_parameter_ of their reducer object was not set properly. I fixed that as well.

Tasks: T68705534

Tags: DDP

Differential Revision: D22176231

Summary: Solves issue #38942.
In reducer.cpp, I check whether `find_unused_param` is set to False via `unused_parameters_.empty()`.
If `unused_parameters_.empty()`, then it avoids `allreduce(local_used_maps_dev_)`.

Test Plan: Run `test/distributed/test_c10d.py` and make sure all tests pass.

Reviewers: Pritam Damania

Subscribers: Pritam Damania, Shen Li

Tasks: T68705534

Tags: DDP

[ghstack-poisoned]
sinannasir added a commit that referenced this pull request Jun 22, 2020
Summary: Solves issue #38942.
In reducer.cpp, I check whether `find_unused_param` is set to False via `unused_parameters_.empty()`.
If `unused_parameters_.empty()`, then it avoids `allreduce(local_used_maps_dev_)`.

Test Plan: Run `test/distributed/test_c10d.py` and make sure all tests pass.

Reviewers: Pritam Damania

Subscribers: Pritam Damania, Shen Li

Tasks: T68705534

Tags: DDP

ghstack-source-id: 3da54f3
Pull Request resolved: #40407
@dr-ci

dr-ci bot commented Jun 23, 2020

💊 CI failures summary and remediations

As of commit fe145bb (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 2/2 non-CircleCI failure(s)

Extra GitHub checks: 2 failed



@mrshenli
Contributor

And since the unused parameters from different processes can be different, let's also add a test for that case.

…param=False`"


Summary: 
1. Solves issue [38942](#38942).
2. In reducer.cpp, I check whether `find_unused_param` is set to `false` via `unused_parameters_.empty()`.
If `unused_parameters_.empty()`, then it avoids `allreduce(local_used_maps_dev_)`.
3. `unused_parameters_` is always empty when we set `find_unused_param = false`. Further, there may be additional cases where `unused_parameters_` is empty but `find_unused_param = true`. Therefore, in addition to issue [38942](#38942), this diff also avoids `allreduce(local_used_maps_dev_)` for that case.
Since we never call `unused_parameters_.clear()` in separate processes, I think this will not cause any issue or lead to a DDP communication hang. In fact, it is beneficial: if all parameters are used, each should already be reduced.

Test Plan: Run `test/distributed/test_c10d.py` and make sure all tests pass.

Reviewers: Pritam Damania

Subscribers: Pritam Damania, Shen Li

Tasks: T68705534

Tags: DDP

Differential Revision: [D22176231](https://www.internalfb.com/intern/diff/D22176231/)

[ghstack-poisoned]
sinannasir added a commit that referenced this pull request Jun 23, 2020
Pull Request resolved: #40407

1. Solves issue [38942](#38942).
2. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in `Reducer::prepare_for_backward`.
If `!find_unused_param_`, then it avoids `allreduce(local_used_maps_dev_)`.

Differential Revision: [D22176231](https://our.internmc.facebook.com/intern/diff/D22176231/)
ghstack-source-id: 106447051
@sinannasir sinannasir requested a review from mrshenli June 23, 2020 20:54
…param=False`"


Summary: 
1. Solves issue [38942](#38942).
2. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in `Reducer::prepare_for_backward`.
If `!find_unused_param_`, then it avoids `allreduce(local_used_maps_dev_)`.

Test Plan: Run `test/distributed/test_c10d.py` and make sure all tests pass.

Reviewers: Pritam Damania

Subscribers: Pritam Damania, Shen Li, Yanli Zhao

Tasks: T68705534

Tags: DDP

Differential Revision: [D22176231](https://www.internalfb.com/intern/diff/D22176231/)

[ghstack-poisoned]
@sinannasir sinannasir requested a review from apaszke as a code owner June 24, 2020 02:05
sinannasir added a commit that referenced this pull request Jun 24, 2020
Pull Request resolved: #40407

1. Solves issue [38942](#38942).
2. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in `Reducer::prepare_for_backward`.
If `!find_unused_param_`, then it avoids `allreduce(local_used_maps_dev_)`.
ghstack-source-id: 106478491

Differential Revision: [D22176231](https://our.internmc.facebook.com/intern/diff/D22176231/)
Contributor

@pritamdamania87 pritamdamania87 left a comment

And since the unused parameters from different processes can be different, let's also add a test for that case.

@mrshenli Could you elaborate on the test case we need here? Looks like we have unit tests where we test both cases for find_unused_parameters_. Is there something missing in our current tests?

sinannasir added a commit that referenced this pull request Jun 24, 2020
Pull Request resolved: #40407

1. Solves issue [38942](#38942).
2. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in `Reducer::prepare_for_backward`.
If `!find_unused_param_`, then it avoids `allreduce(local_used_maps_dev_)`.
ghstack-source-id: 106487672

Differential Revision: [D22176231](https://our.internmc.facebook.com/intern/diff/D22176231/)
@mrshenli
Contributor

@mrshenli Could you elaborate on the test case we need here? Looks like we have unit tests where we test both cases for find_unused_parameters_. Is there something missing in our current tests?

Sure. The first two attempts in this PR remind us that we might want to add tests for the following cases:

  1. unused_parameters_.empty() does not imply find_unused_parameters: we can set find_unused_parameters to true, and then let some processes use all parameters while others don't. (A rough test sketch follows the code excerpt below.)
  2. outputs in prepare_for_backward does not imply find_unused_parameters: the code today actually infers find_unused_parameters from outputs. However, since find_unused_parameters is now passed explicitly, we can modify the following code accordingly and add tests for it, where the forward pass in some process produces empty outputs:

// If no outputs are specified, we assume that autograd hooks for ALL
// variables will be called, and we don't have to search the autograd graph
// for presence of these hooks.
if (outputs.empty()) {
  return;
}
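
A rough sketch of what a test for case 1 could look like (module and helper names are hypothetical, not the test this PR ultimately adds; assumes the default process group is already initialized):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

class PartiallyUsedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x, use_fc2):
        x = torch.relu(self.fc1(x))
        return (self.fc2(x) if use_fc2 else x).sum()

def run_case_1(rank):
    # unused_parameters_ is empty on rank 0 (which uses every parameter)
    # but non-empty on the other ranks, so it cannot stand in for
    # find_unused_parameters.
    ddp_model = DDP(PartiallyUsedModel(), find_unused_parameters=True)
    loss = ddp_model(torch.randn(4, 8), use_fc2=(rank == 0))
    loss.backward()  # must not hang even though ranks disagree on usage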

@sinannasir sinannasir requested review from mrshenli, pritamdamania87 and zhaojuanmao and removed request for zhaojuanmao June 24, 2020 15:53
@sinannasir
Contributor Author

  1. All tests pass and we don't allocate memory for local_used_maps and local_used_maps_dev_ if find_unused_parameters is false.
  2. Fixed declaring global_unused twice when find_unused_parameters_ = true.
  3. Correctly passing find_unused_parameters for test case test_forward_backward_unused_parameters.
  4. Included Note [Skip allreducing local_used_maps_dev_] for reference.

sinannasir added a commit that referenced this pull request Jun 24, 2020
Pull Request resolved: #40407

1. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in the Reducer's Python API.
2. Solves issue [38942](#38942).
When find_unused_parameters_ is set to false, there is no need to allreduce local_used_maps_dev_, because all parameters will be reduced anyway. Therefore, we can avoid allocating memory for local_used_maps and local_used_maps_dev_ if find_unused_parameters_ is false.

ghstack-source-id: 106525355

Differential Revision: [D22176231](https://our.internmc.facebook.com/intern/diff/D22176231/)
@sinannasir
Contributor Author

Reverted some unintended VS Code auto-formatting changes.

@sinannasir
Contributor Author

@mrshenli Could you elaborate on the test case we need here? Looks like we have unit tests where we test both cases for find_unused_parameters_. Is there something missing in our current tests?

Sure. The first two attempts in this PR remind us that we might want to add tests for the following cases:

  1. unused_parameters_.empty() does not imply find_unused_parameters: we can set find_unused_parameters to true, and then let some processes use all parameters while others don't.
  2. outputs in prepare_for_backward does not imply find_unused_parameters: the code today actually infers find_unused_parameters from outputs. However, since find_unused_parameters is now passed explicitly, we can modify the following code accordingly and add tests for it, where the forward pass in some process produces empty outputs:

// If no outputs are specified, we assume that autograd hooks for ALL
// variables will be called, and we don't have to search the autograd graph
// for presence of these hooks.
if (outputs.empty()) {
  return;
}

I am looking into it.

sinannasir added a commit that referenced this pull request Jun 25, 2020
Pull Request resolved: #40407

1. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in the Reducer's Python API.
2. Solves issue [38942](#38942).
When find_unused_parameters_ is set to false, there is no need to allreduce local_used_maps_dev_, because all parameters will be reduced anyway. Therefore, we can avoid allocating memory for local_used_maps and local_used_maps_dev_ if find_unused_parameters_ is false.

ghstack-source-id: 106644409

Differential Revision: [D22176231](https://our.internmc.facebook.com/intern/diff/D22176231/)
@sinannasir
Contributor Author

sinannasir commented Jun 26, 2020

The new changes in the final commit:

  1. A new test case test_find_unused_parameters_when_unused_parameters_empty is included. Today's reducer.cpp fails that unit test because it infers find_unused_parameters_ from unused_parameters_:

if (!has_rebuilt_bucket_ && unused_parameters_.empty() &&
    index.replica_index == 0) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

// If there are model parameters that went unused when computing the model
// output, they won't be part of the autograd graph, and won't receive
// gradients. These parameters are discovered in the `prepare_for_backward`
// function and their indexes stored in the `unused_parameters_` vector.
if (!has_marked_unused_parameters_ && !unused_parameters_.empty()) {
  has_marked_unused_parameters_ = true;
  for (const auto& unused_index : unused_parameters_) {
    mark_variable_ready(unused_index);
  }
}

The problem originates in backward, but the exception surfaces in the constructor of the next ddp_model: after the first problematic backward call, the broadcast process_group_->broadcast(vec)->wait(); starts passing corrupted control_accessor entries. This may be related to the Gloo process group. The test now passes after the new modifications.

  2. Writing a test scenario for the following statement is not possible because of issue #40566 ("If all parameters are unused by forward pass in a process, backward will not work with DDP."):

outputs in prepare_for_backward does not imply find_unused_parameters: the code today actually infers find_unused_parameters from outputs.

I still changed the way find_unused_parameters_ is checked in the prepare_for_backward function.

  3. Two test cases, test_forward_backward_unused_parameters and test_forward_backward_optimizer, were failing because find_unused_parameter_ of their reducer object was not set properly. I fixed that as well.

Contributor

@pritamdamania87 pritamdamania87 left a comment

Looks like there are some merge conflicts in CI; you probably need to rebase and resubmit the PR.

Comment on lines +396 to +408
if (!has_rebuilt_bucket_ && !find_unused_parameters_ &&
    index.replica_index == 0) {
  rebuilt_params_.push_back(
      replicas_[index.replica_index][index.variable_index]);
  rebuilt_param_indices_.push_back(index.variable_index);
}

-// If there are model parameters that went unused when computing the model
-// output, they won't be part of the autograd graph, and won't receive
-// gradients. These parameters are discovered in the `prepare_for_backward`
-// function and their indexes stored in the `unused_parameters_` vector.
-if (!has_marked_unused_parameters_ && !unused_parameters_.empty()) {
+// If `find_unused_parameters_` is true there may be model parameters that
+// went unused when computing the model output, they won't be part of the
+// autograd graph, and won't receive gradients. These parameters are discovered
+// in the `prepare_for_backward` function and their indexes stored in
+// the `unused_parameters_` vector.
+if (!has_marked_unused_parameters_ && find_unused_parameters_) {
Contributor

These changes look good to me, although I'll let Shen stamp the PR.

Contributor Author

Thanks for letting me know about the merge conflict :) Just rebased and resolved the conflict in init.cpp.

…param=False`"


Summary: 
1. In reducer.cpp, we have a new boolean `find_unused_param_` and its value is set in `Reducer::prepare_for_backward`.
If `!find_unused_param_`, then it avoids `allreduce(local_used_maps_dev_)`.
2. Solves issue [38942](#38942).
3. Fixes incorrect inference of `find_unused_parameters_` from checks such as `outputs.empty()` or `unused_parameters_.empty()`.

Test Plan: 
1. Run `test/distributed/test_c10d.py` and make sure all tests pass.
2. A new test case `test_find_unused_parameters_when_unused_parameters_empty` is included. The old `reducer.cpp` failed that unit test because it inferred `find_unused_parameters_` from `unused_parameters_.empty()`; the current `reducer.cpp` passes it.
3. Two test cases, `test_forward_backward_unused_parameters` and `test_forward_backward_optimizer`, were failing because `find_unused_parameter_` of their `reducer` object was not set properly. I fixed that as well.

Tasks: T68705534

Tags: DDP

Differential Revision: [D22176231](https://www.internalfb.com/intern/diff/D22176231/)

[ghstack-poisoned]
Contributor

@zhaojuanmao zhaojuanmao left a comment

overall lgtm

Contributor

@mrshenli mrshenli left a comment

LGTM! Thanks for working on this!

}
}

// Note [Skip allreducing local_used_maps_dev]
Contributor

Just curious, any reason for adding this note to the dtor?

Contributor Author

In Pritam's example, the notes were placed after the constructor. That's why I put it there.
https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.cpp#L100. Example of referring to a note: https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.cpp#L64.

output.mean().backward()

# Now locally unused parameter should have grad updated on all ranks.
[self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]
Contributor

nit: I wonder if we should also verify that the grad values are as expected? We can do this by:

  1. creating two local models, one whose forward uses both t0 and t1, and another whose forward only uses t1
  2. running forward/backward on them, and then computing the mean grads of these two models
  3. comparing the mean grads with DDP's grads

I noticed the test above (test_global_local_unused_params_grad) didn't test the accuracy either, so feel free to do this in a followup PR. (A rough sketch of such a check follows.)
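
A rough sketch of that check (model and helper names are hypothetical; it assumes a base model whose forward takes a flag selecting whether both task parameters t0 and t1 are used, and a world size of two):

import copy
import torch

def check_grad_accuracy(ddp_model, base_model, x0, x1):
    # Two local replicas with identical weights, mirroring the two ranks:
    # one forward uses both t0 and t1, the other uses only t1.
    m_both = copy.deepcopy(base_model)
    m_t1 = copy.deepcopy(base_model)
    m_both(x0, use_both=True).sum().backward()
    m_t1(x1, use_both=False).sum().backward()

    for p_both, p_t1, p_ddp in zip(m_both.parameters(), m_t1.parameters(),
                                   ddp_model.parameters()):
        g0 = p_both.grad if p_both.grad is not None else torch.zeros_like(p_both)
        g1 = p_t1.grad if p_t1.grad is not None else torch.zeros_like(p_t1)
        expected = (g0 + g1) / 2  # DDP averages gradients across ranks
        assert torch.allclose(p_ddp.grad, expected)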

Contributor Author

I noted this and will do a followup soon.

process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)

# Test on CPU
cpu_model = DistributedDataParallel(
Contributor

nit: In a follow-up PR, it may be worth separating the GPU and CPU tests so that (a portion of) this test can run in CPU-only builds. (A small sketch of such a split follows.)
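
A small sketch of that split (test and helper names are hypothetical, not from test_c10d.py): factor the device-agnostic body into a helper so the CPU half runs everywhere.

import torch

def _check_find_unused_parameters(make_ddp_model, device):
    # Device-agnostic body shared by the CPU and GPU variants;
    # make_ddp_model is a caller-supplied factory (hypothetical).
    model = make_ddp_model(device)
    output = model(torch.randn(4, 8, device=device))
    output.sum().backward()

def test_find_unused_parameters_cpu(make_ddp_model):
    # Runs in CPU-only builds.
    _check_find_unused_parameters(make_ddp_model, torch.device("cpu"))

def test_find_unused_parameters_gpu(make_ddp_model):
    if not torch.cuda.is_available():
        return  # skipped in CPU-only builds
    _check_find_unused_parameters(make_ddp_model, torch.device("cuda"))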

Contributor Author

Yes, I noted that.

@facebook-github-bot facebook-github-bot deleted the gh/sinannasir/1/head branch July 30, 2020 14:17
