[RFC] grouping tensors in C++ for fused adam(w) #94344
Conversation
Dr. CI: See artifacts and rendered test results at hud.pytorch.org/pr/94344. ❌ 1 failure as of commit f0b08c8. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Nice!
With the huggingface/transformers gpt2-medium example (script) on A100x4: master (2.0.0a0+git1767026) vs. this PR. Machine: DGX Station (A100x4, AMD EPYC 7742 x 64). Command:

In the first run, this check (line 60 in f9cc12e) errored out, so I removed it and things worked. So I'll remove it.
t5-large: master vs. this PR. Command:
Force-pushed from 8ff00e0 to cff18cb.
Thanks for looking into this! Is your long-term goal here to replace the Python grouping util with a generalized C++ util, or is this PR just for fused AdamW?
The former. That said, I'm not sure how to avoid multiple groupings for
Yea, I agree with the concern if we were to expand foreach grouping to C++ as a generic util.

I have a more general concern about adding grouping for general foreach ops (like add, mul, etc.): will this add significant perf cost for foreach calls where the tensors were already properly grouped? What is the perf difference if we added grouping for every foreach op vs. the existing checks we do today? Do you have some benchmarks comparing the existing checks logic vs. the grouping logic?

If the difference is significant, would a potential solution be to allow users to pass in a skip_grouping flag that says "I'm sure I don't need grouping, just run + if it fails it's my fault"? This would require an API change though... which is likely BC breaking.

cc'ing @ngimel as well
I had a similar thought: adding a flag to each foreach function, and/or exposing from C++ to Python some function which takes a list like the one in torch/cuda/amp/grad_scaler.py (line 195 in c16b291).
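For concreteness, a minimal sketch of the kind of check such an exposed C++ helper could perform (the name `tensors_share_device_and_dtype` is hypothetical, not anything in this PR or in ForeachUtils.h):

```cpp
// Hypothetical helper (illustration only): returns true when a flat tensor list
// already shares one device and dtype, i.e. grouping/fallback could be skipped.
#include <ATen/ATen.h>
#include <vector>

bool tensors_share_device_and_dtype(const std::vector<at::Tensor>& tensors) {
  if (tensors.empty()) {
    return true;
  }
  const auto device = tensors.front().device();
  const auto dtype = tensors.front().scalar_type();
  for (const auto& t : tensors) {
    if (t.device() != device || t.scalar_type() != dtype) {
      return false;
    }
  }
  return true;
}
```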
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased
Force-pushed from cff18cb to 93384db.
I would agree that having a separate C++ util exposed that could be optionally called from relevant endpoints would be better than changing the API, haha. Do you have the numbers for how long the grouping would take with the gpt2-medium params?
…95847) Fixes #95781. The cause seems to be that the current implementation doesn't correctly pass `found_inf` when `grad_scale` is `None`, so parameters can get mistakenly updated by gradients some of whose elements are invalid, i.e. nan or inf. Related: #94060. I forgot about this wrong handling after #94344. Pull Request resolved: #95847. Approved by: https://github.com/janeyx99
Force-pushed from 93384db to 39681b9.
Some more numbers (top: this PR, bottom: master). "Self" numbers look good.

FP32 - the same gpt2 script as above
FP16 AMP (autocast & GradScaler)
Is `only_device_check` simply `is_step_tensor`? Nit: consider a more indicative name if so.
The numbers do look good! I have some questions:
Force-pushed from 39681b9 to 0c82432.
Yes. IIUC the difference mainly comes from C++ vs. Python, including setting the device guard and handling the map/dict of device to Tensor for the grad scaler & found_inf.
I've been playing with line 49 in ae3316c.
I guess having a flag to tell whether or not momentum buffers need initialization would work well, e.g. https://github.com/NVIDIA/apex/blob/7150e20cc3adef34e3f36261a1070fc0882f16a7/apex/optimizers/fused_sgd.py#L121-L136, though I think that would be backward incompatible.
aten/src/ATen/native/ForeachUtils.h (outdated)
Should this logic just be moved before the for loop so it doesn't keep checking with `count`?
aten/src/ATen/native/ForeachUtils.h (outdated)
To my understanding, this function would only be used by the optimizers, yes?
Will you be switching this to call the generic foreach util once that's ready? I think that would be ideal (vs having a copy in the fused_adam_utils) and I'd be curious if the perf results would be different.
Ah, so this means the wins won't be as prominent for the other optimizers/foreach ops, but should be a general win nonetheless.
Is the problem here that the momentum buffer list may be all Nones initially? I think the way I dealt with it before was just including the `with_indices` flag and having SGD use the indices to update the original buffer. Is that insufficient?
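As a rough illustration of the `with_indices` idea (a hypothetical standalone sketch, not the actual util), grouping can carry each tensor's original position so the caller, e.g. SGD with lazily-created momentum buffers, can scatter results back into the original list:

```cpp
// Illustration only: group a flat tensor list by (device index, dtype) while
// remembering each tensor's original index. Assumes CUDA tensors, so the device
// index alone is enough to distinguish devices.
#include <ATen/ATen.h>
#include <map>
#include <utility>
#include <vector>

struct IndexedGroup {
  std::vector<at::Tensor> tensors;
  std::vector<int64_t> indices;  // positions in the original list
};

std::map<std::pair<int, int>, IndexedGroup> group_with_indices(
    const std::vector<at::Tensor>& tensors) {
  std::map<std::pair<int, int>, IndexedGroup> grouped;
  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto& t = tensors[i];
    // Plain-int key avoids relying on ordering operators for Device/ScalarType.
    const std::pair<int, int> key{static_cast<int>(t.device().index()),
                                  static_cast<int>(t.scalar_type())};
    grouped[key].tensors.push_back(t);
    grouped[key].indices.push_back(static_cast<int64_t>(i));
  }
  return grouped;
}
```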
…95847) (#97885) Fixes #95781. The cause seems to be that the current implementation doesn't correctly pass `found_inf` when `grad_scale` is `None`, so parameters can get mistakenly updated by gradients some of whose elements are invalid, i.e. nan or inf. Related: #94060. I forgot about this wrong handling after #94344. Pull Request resolved: #95847. Approved by: https://github.com/janeyx99
Force-pushed from 0c82432 to 7f92a8b.
Hey! Why did you close the PR? Sorry if you had been waiting on a review; I did not think it was ready before and would be happy to review.
@janeyx99 I found this branch a bit too messed up, so I refactored it into #100007.

Thanks for the update -- I'll look there!
rel: #94344 Pull Request resolved: #100007 Approved by: https://github.com/janeyx99
This duplicates the core functionality of https://github.com/pytorch/pytorch/blob/master/torch/utils/_foreach_utils.py#L21 in C++.
The changes are:
- `device_guard` and calling `device_check` in each fused kernel

Ideally there should be an equivalent in C++ which we can use for the foreach optimizers and each foreach function so that we can remove the for-loop in Python. Note that once each foreach function takes care of this grouping, the foreach optimizers could redundantly do the grouping; so still iterating over the lists of grouped tensors could make sense in the foreach optimizers.
related: #58833, #89591 (comment)
cc @ptrblck @ngimel
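For readers skimming this thread, here is a rough sketch of the idea under discussion: bucketing parallel tensor lists by (device, dtype) in C++, mirroring the Python util linked above. The function name, key encoding, and signature are illustrative only and do not reflect the code added in ForeachUtils.h.

```cpp
// Illustrative sketch: bucket parallel tensor lists (e.g. params, grads,
// exp_avgs, exp_avg_sqs) by the (device, dtype) of the reference list, so each
// bucket can be handed to one fused kernel launch under a single device guard.
// Assumes CUDA tensors and that corresponding tensors across lists match.
#include <ATen/ATen.h>
#include <map>
#include <utility>
#include <vector>

using TensorLists = std::vector<std::vector<at::Tensor>>;
using GroupKey = std::pair<int, int>;  // (device index, scalar type) as ints

std::map<GroupKey, TensorLists> group_tensors_by_device_and_dtype(
    const TensorLists& tensor_lists) {
  std::map<GroupKey, TensorLists> grouped;
  if (tensor_lists.empty() || tensor_lists[0].empty()) {
    return grouped;
  }
  const auto num_lists = tensor_lists.size();
  const auto num_tensors = tensor_lists[0].size();
  for (size_t i = 0; i < num_tensors; ++i) {
    const auto& ref = tensor_lists[0][i];  // first list is the reference
    const GroupKey key{static_cast<int>(ref.device().index()),
                       static_cast<int>(ref.scalar_type())};
    auto& bucket = grouped[key];
    if (bucket.empty()) {
      bucket.resize(num_lists);
    }
    for (size_t j = 0; j < num_lists; ++j) {
      bucket[j].push_back(tensor_lists[j][i]);
    }
  }
  return grouped;
}
```

Each bucket would then be processed with the device guard set once per group rather than per tensor, which is the per-tensor overhead the description refers to.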