[WIP] Rewrote adam optimizer with foreach APIs #43507
Conversation
💊 CI failures summary and remediations: as of commit 1402b87 (more details on the Dr. CI page), 🕵️ 9 new failures were recognized by patterns. These CI failures do not appear to be due to upstream breakages.
ngimel left a comment:
This generally looks good, but it reveals that we need a few more (TensorList, ScalarList) operations, because there are still a few loops over parameters/grads.
for p in group['params']:
    if p.grad is not None:
        params_with_grad.append(p)
        grads.append(p.grad)
You could check for sparse gradients here, not in a separate loop.
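A sketch of the suggested merge, reusing the names from the diff above (the sparse-gradient error message is assumed to match the eager Adam's wording):

```python
for p in group['params']:
    if p.grad is not None:
        if p.grad.is_sparse:
            # Same guard the eager Adam performs, just moved into this loop.
            raise RuntimeError('Adam does not support sparse gradients, '
                               'please consider SparseAdam instead')
        params_with_grad.append(p)
        grads.append(p.grad)
```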
Done
torch/optim/multi_tensor/adam.py
Outdated
grads.append(p.grad)

for p in params_with_grad:
    ...
for g in grads:
    ...
Remove loop over grads, check them earlier.
Done
torch/optim/multi_tensor/adam.py
Outdated
bias_correction1 = [1 - beta1 ** state['step'] for state in states]
bias_correction2 = [1 - beta2 ** state['step'] for state in states]
if group['weight_decay'] != 0:
    torch._foreach_add_(grads, group['params'], group['weight_decay'])
In the original optimizer the line is
grad = grad.add(p, alpha=group['weight_decay'])
This is not in-place, and the original grad attributes (p.grad) aren't mutated. Here you are mutating p.grad in place (I'm surprised the tests don't catch it).
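One way to keep the semantics of the original line (a sketch, assuming the out-of-place `torch._foreach_add` overload that accepts `alpha`):

```python
if group['weight_decay'] != 0:
    # Rebinds `grads` to freshly allocated tensors, leaving each p.grad
    # untouched, mirroring grad = grad.add(p, alpha=group['weight_decay']).
    grads = torch._foreach_add(grads, params_with_grad, alpha=group['weight_decay'])
```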
Done
if amsgrad:
    # Maintains the maximum of all 2nd moment running avg. till now
    max_exp_avg_sq = [torch.max(a, b) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]
Aha, so ideally we also need a foreach max, because now it will be a loop?
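For reference, with such a foreach variant the per-tensor loop would collapse into a single call; the name used here is hypothetical:

```python
# Hypothetical in-place element-wise maximum over two tensor lists.
torch._foreach_maximum_(max_exp_avg_sq, exp_avg_sq)
```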
torch/optim/multi_tensor/adam.py
Outdated
# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
max_exp_avg_sq_sqrt = [torch.div(a, b) for a, b in zip(max_exp_avg_sq_sqrt, bias_correction_sqrt)]
Ok, so this is a loop because we don't have Op(TensorList, ScalarList)? This is unfortunate, looks like we really need it.
Per the conversation in Slack: we can check whether all steps are the same, which guarantees that all bias_corrections are the same, and then use foreach here and in other places when that's the case (which should be common).
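A sketch of that fast path, using the names from the diff above (the scalar `_foreach_div_` overload is assumed to be available):

```python
steps = [state['step'] for state in states]
if all(step == steps[0] for step in steps):
    # All params are at the same step, so the bias correction is one scalar
    # and a single foreach call replaces the per-tensor Python loop.
    bias_correction2_sqrt = math.sqrt(1 - beta2 ** steps[0])
    torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
else:
    # Fall back to the per-tensor division when steps differ.
    max_exp_avg_sq_sqrt = [a / math.sqrt(bc)
                           for a, bc in zip(max_exp_avg_sq_sqrt, bias_correction2)]
```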
Are you suggesting a change in the underlying algorithm? What does it mean, and what would happen if all steps are not the same?
step_size = [group['lr'] / bc for bc in bias_correction1]

for i in range(len(step_size)):
    params_with_grad[i].addcdiv_(exp_avg[i], denom[i], value=-step_size[i])
And this is another case where we need op(TensorList, ScalarList), because params_with_grad is a TensorList, exp_avg is a TensorList, denom is a TensorList, and value is a ScalarList?
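With such an op(TensorList, TensorList, TensorList, ScalarList), the loop above could become a single call; the ScalarList overload of `_foreach_addcdiv_` shown here is hypothetical:

```python
# Hypothetical: addcdiv_ over tensor lists with a per-tensor scalar `value`.
torch._foreach_addcdiv_(params_with_grad, exp_avg, denom,
                        [-s for s in step_size])
```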
max_exp_avg_sq = [torch.max(a, b) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]
# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
Why is this not a _foreach_sqrt?
In this context, bias_correction2 is a list of scalars; there is no foreach API that supports lists of scalars. I'm working on those APIs right now.
This particular case is OK: those are Python scalars and this operation is reasonably fast. However, other cases where CUDA tensors interact with a list of scalars are more problematic. We could get around them for now by checking whether all bias_corrections have the same value (which should be a common case).
I gave myself two days to fight codegen and make ScalarList a reality. If it works out, I will add new APIs for _foreach_op(TensorList, ScalarList). If it turns out to be too complex, I will make a workaround and add it to the TODO list.
In case there is a way of avoiding duplicating code:
This means some operations could be slower with MultiTensor for now, right? Has this been measured?
Now that we have the implementation, have you been able to quantify the performance scaling?
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
    optim.Adam(None, lr=1e-2, weight_decay=-1)
for optimizer in [optim.Adam, optim_mt.Adam]:
    ...
Note: the current tests don't guarantee that the algorithm is the same as the original, or even that there's convergence in either case.
@zou3519 @anjali411 -- how are the C++ APIs tested?
C++ Optimizer API logic is tested here: https://github.com/pytorch/pytorch/blob/master/test/cpp/api/optim.cpp#L311. These tests compare the C++ optimizers' results to the Python API optimizers' results prewritten in this file: https://github.com/pytorch/pytorch/blob/master/test/cpp/api/optim_baseline.h
Ok, so here we'll need the tests comparing optim.Adam results with optim_mt.Adam, in addition to _test_basic_cases?
Yes, since the contract here is that the MultiTensor implementation is the same as the original.
Is there also a C++ equivalent with the foreach API for the C++ optimizer?
@vincentqb, there will be C++ optimizers as well, but a bit later. We decided to start with the Python ones first.
Re testing: is there anything specific you would suggest testing?
Ideally, we'd have a test as mentioned by @anjali411 above that checks that the two implementations give the exact same answer in some case.
Note that if this implementation were directly replacing the current one, the C++ test would also tell us that the implementations are still aligned :) Maybe there's a way of leveraging those tests already there?
There are tests in apex comparing optimizer implementations, they can be adopted here if c++ tests are hard to use for some reason https://github.com/NVIDIA/apex/blob/master/tests/L0/run_optimizers/test_fused_optimizer.py
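A minimal sketch of such a parity test (the `optim_mt` alias and the `torch.optim.multi_tensor` namespace follow this PR's description; the model, step count, and tolerance are placeholders):

```python
import copy

import torch
import torch.optim as optim
import torch.optim.multi_tensor as optim_mt  # namespace introduced in this PR


def test_adam_matches_multi_tensor_adam():
    torch.manual_seed(0)
    ref_model = torch.nn.Linear(8, 8)
    mt_model = copy.deepcopy(ref_model)
    ref_opt = optim.Adam(ref_model.parameters(), lr=1e-3)
    mt_opt = optim_mt.Adam(mt_model.parameters(), lr=1e-3)

    for _ in range(5):
        inp = torch.randn(4, 8)
        for model, opt in ((ref_model, ref_opt), (mt_model, mt_opt)):
            opt.zero_grad()
            model(inp).sum().backward()
            opt.step()

    # Both implementations should produce (near-)identical parameters.
    for p_ref, p_mt in zip(ref_model.parameters(), mt_model.parameters()):
        assert torch.allclose(p_ref, p_mt, atol=1e-7)
```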
This is wonderful, thank you! I don't see any user docs discussing this improvement. What is the implication for the user? Should we switch to using the new optimizers? Thank you.
Hi @stas00. Answering your question: you can try using the optimizers from the new torch.optim.multi_tensor namespace.
Thank you very much for the clarification and the stack, @izdeby! So basically we have the option to deploy these early for those who need the speed-up sooner, but otherwise there is nothing to be done. Excellent!
If I want to implement global operators like grad_clip myself to reduce kernel launches, may I use multi_tensor to do it, or does PyTorch provide similar interfaces?
PyTorch's _foreach_* / multi_tensor APIs can be used for that.
Thank you so much. I think multi_tensor is what I'm looking for.
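For illustration, a minimal sketch of global-norm gradient clipping written with the `_foreach_*` APIs (the `_foreach_norm` and `_foreach_mul_` calls are assumptions and are not part of this PR):

```python
import torch


def clip_grad_norm_foreach(parameters, max_norm):
    # Gather the gradients once, then use multi-tensor ops instead of a
    # per-tensor Python loop.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    total_norm = torch.norm(torch.stack(torch._foreach_norm(grads)))
    clip_coef = float(max_norm) / (float(total_norm) + 1e-6)
    if clip_coef < 1.0:
        torch._foreach_mul_(grads, clip_coef)  # scales every grad in one call
    return total_norm
```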
Stack from ghstack:
Differential Revision: D23331893
Motivation
GitHub issue #38655
Current PyTorch optimizer implementations are not efficient when we work with a lot of small feature tensors: launching many kernels slows down the whole process, so we need to reduce the number of kernels we launch.
As an example, we should be looking at NVIDIA's Apex.
To track progress, we will pick PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark the model's performance against the original model version, Apex's version with the original Adam optimizer, and its FusedAdam optimizer.
Current API restrictions
- The list can't be empty (this will be fixed in upcoming PRs).
- All tensors in the list must have the same dtype, device and size.
Broadcasting
At this point we don't support broadcasting.
What the 'fast' and 'slow' routes are
In particular cases, we can't process an op with a fast list CUDA kernel. We can still fall back to a regular for-loop where the op is applied to each tensor individually through the dispatch mechanism. A few checks decide whether the op is performed via the 'fast' or the 'slow' path.
To go the fast route:
- All tensors must have strided layout.
- All tensors must be dense and must not have overlapping memory.
- The resulting tensor must have the same dtype.
- All tensors must be on the same device.
In this PR
- We are introducing a new namespace under torch.optim, torch.optim.multi_tensor, where we will have optimizers rewritten with the _foreach_* APIs.
- We are rewriting the Adam optimizer with the _foreach_* APIs.
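A minimal usage sketch of the new optimizer (assuming the `torch.optim.multi_tensor` namespace described above):

```python
import torch
import torch.optim.multi_tensor as optim_mt  # new namespace from this PR

model = torch.nn.Linear(10, 10)
optimizer = optim_mt.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()       # parameter update done with _foreach_* calls internally
optimizer.zero_grad()
```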