[optim] skip .item calls in all optimizers when compiling with dynamo #88173
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88173
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure
As of commit 5a05a87, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
After this PR lands, I suspect we can remove the …
albanD
left a comment
Does this make it capturable? These are still CPU Tensors.
Also, this will trigger #74424 again, so I'm not sure if we can do this.
I don't think it will trigger that issue with XLA because we've removed all …
Capturable (I believe) is whether dynamo can get a whole graph of it -- not that it should all be on GPU.
Capturable is used in the context of cudagraph: "whether this instance is safe to capture in a CUDA graph."
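For reference, `capturable` is exposed as a constructor flag on the CUDA-graph-aware optimizers such as Adam; a minimal usage sketch (not code from this PR):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
# `capturable=True` keeps the optimizer state (including `step`) on the GPU
# so the update is safe to capture in a CUDA graph, at the cost of extra
# small device-side ops when running the step eagerly.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)
```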
Oh, interesting. In that case, I have to check …
I think there are two things here:
We could consider checking during the step itself if we're using cudagraph, but that would mean we might have to move the step then (which is a sync point, so not good). Also, I'm not sure how dynamo will handle that since it wants to use cudagraph; I guess it already breaks on any mix of CPU/GPU compute? And so this change will not really remove the graph breaks within dynamo?
@albanD One question: if the step is on GPU, this would be bad without cudagraphs because we're essentially launching a kernel to do step + 1, right? I think in the dynamo case this won't matter, because this op will get fused into a larger kernel anyway, so could we make it the default?
Correct.
We can't make it the default, no. For GPU-like backends, having these small ops on the device is prohibitively slow (#74424).
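A hypothetical illustration of the trade-off being discussed (none of this is PR code): keeping `step` as a CPU scalar tensor avoids a tiny CUDA kernel launch per parameter in eager mode, while a device tensor is what CUDA graph capture needs.

```python
import torch

# Hypothetical illustration only: contrast a CPU `step` with a device `step`.
step_cpu = torch.tensor(0.0)                 # CPU scalar state: cheap to bump eagerly
step_gpu = torch.tensor(0.0, device="cuda")  # device state: capturable, but every
                                             # update is its own small kernel launch

step_cpu += 1  # plain CPU op, no kernel launch
step_gpu += 1  # tiny CUDA kernel; prohibitively slow if repeated per
               # parameter per iteration in eager mode (see #74424)
```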
albanD
left a comment
Sounds good to me!
I guess XLA is not a GPU-like backend as it is lazy?
It seems the overhead mainly comes from moving those tensors back to CPU, which is exactly the issue this PR fixes. cc @JackCaoG. I believe it would be beneficial to make capturable=True by default for all fusion-capable backends, e.g. XLA or CUDA graphs.
Generally speaking, removing …
In the dynamo -> XLA bridge for training, we are not actually running the optimizer (https://github.com/pytorch/pytorch/blob/master/benchmarks/dynamo/common.py#L894). But it looks like people are adding them.
If … See https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py#L366-L393
Do we have the sanity-check benchmark discussed above, to make sure nothing is obviously broken in terms of speed due to these changes?
Update: this causes a pretty bad perf regression in both single- and multi-tensor mode for eager. As a workaround I will use … cc @albanD
Update: the current version matches the eager perf of the old version, per Alban's request. Should be good to go, just needs a sign-off. Updated numbers … The test issues are related to the dynamo import in optim; they should be fixed now, I think.
albanD
left a comment
That is acceptable, but really not great in terms of readability.
param.addcdiv_(grad, denom)
param.addcdiv_(exp_avg, denom)
else:
    mu_product_next = _get_value(mu_product) * mu * mu_next
Why do you do the _get_value on mu_product here and not on mu_product_next like in the old code?
I don't think symint is ready for prime time since there are still some lingering bugs. Once dynamic shapes is on by default I'm fine with removing this. Yes, we usually don't want to modify code to compile with dynamo; in most other cases we don't care about eager perf as much, though. Like in this case, if we didn't care about eager perf I'd just remove the item calls and call it a day, since dynamo provides a good speedup. It ends up being a trade-off of whether it's worth modifying the compiler to add feature x vs. just removing x from the code if it isn't really necessary. With user code I lean towards the former, while in code we own it's more balanced.
@pytorchbot merge -f "Unrelated test failure"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
I would challenge that, actually. People that write library code (nn, vision, audio, etc.) need to support both, so this is a major use case!
@mlazos: skips item() calls if compiling with dynamo, by defining a helper function _get_value which either returns the result of .item() or the scalar CPU tensor if compiling with dynamo. This was done because removing item() calls significantly regresses eager perf. Additionally, _dispatch_sqrt calls the appropriate sqrt function (math.sqrt or torch.sqrt).
Fixes pytorch/torchdynamo#1083
This PR will no longer be needed once symint support is default.
This PR closes all remaining graph breaks in the optimizers (!!)
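A minimal sketch of the two helpers described above, assuming a dynamo compile-time check along the lines of torch._dynamo.is_compiling(); the exact gating and signatures in torch/optim may differ:

```python
import math
import torch

def _get_value(x):
    # Sketch: when compiling with dynamo, keep the scalar CPU tensor in the graph
    # instead of calling .item(), which would force a graph break. In eager mode,
    # .item() is kept because removing it significantly regresses eager perf.
    if torch._dynamo.is_compiling():
        return x
    return x.item()

def _dispatch_sqrt(x):
    # Sketch: dispatch to torch.sqrt for tensors and math.sqrt for Python scalars.
    if isinstance(x, torch.Tensor):
        return torch.sqrt(x)
    return math.sqrt(x)
```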
cc @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire