Make AdamW, NAdam & RAdam differentiable #86183
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86183. Note: links to docs will display an error until the docs builds have completed. ✅ No Failures, 1 Pending as of commit dc2140d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/optim/nadam.py (Outdated)
is `.item()` needed here, in fact?
Yeah... `addcdiv_` requires `value` to be a Python scalar. I guess we could unify both execution paths to avoid device synchronization, but I'm not sure. cc @albanD (who was also concerned about this)
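For context, a minimal sketch of why the two paths diverge (illustrative only, not the PR's actual code; `apply_update` and `step_size` are made-up names): `addcdiv_`'s `value=` must be a Python number, so the eager path calls `.item()`, which synchronizes with the device when the scalar lives on the GPU, while a differentiable path keeps the step size as a tensor, e.g. by folding it into the numerator.

```python
import torch

def apply_update(param, exp_avg, denom, step_size, differentiable):
    # Hypothetical helper, only to illustrate the trade-off discussed above.
    if differentiable:
        # Keep step_size as a tensor so autograd can trace through it:
        # fold it into the numerator instead of passing it as `value=`.
        param.addcdiv_(exp_avg * step_size, denom)
    else:
        # `value=` must be a Python scalar; .item() forces a device sync
        # when step_size is a GPU tensor.
        param.addcdiv_(exp_avg, denom, value=step_size.item())
```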
Maybe worth introducing an idiomatic utility `clone_if(flag)` lambda?
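Something along these lines, presumably (a hedged sketch; `clone_if` is not an existing API, and the buffer name is made up):

```python
import torch

def clone_if(t, flag):
    # Copy only when running the differentiable path, so in-place updates
    # don't overwrite values autograd may still need; otherwise reuse t.
    return t.clone() if flag else t

exp_avg = torch.zeros(3)
buf = clone_if(exp_avg, flag=True)     # independent copy
alias = clone_if(exp_avg, flag=False)  # same storage as exp_avg
```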
albanD left a comment:
Wait, is this included in the other PR? :p
Force-pushed from e30bc50 to 83fe75e (Compare)
Separated all the PRs!
albanD left a comment:
Sounds good to me!
Force-pushed from 83fe75e to c19e5fc (Compare)
Any thoughts about the `clone_if` idiom? Or are explicit, separate if-statement paths better?
I'm not sure?
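For reference, the two styles under discussion might look like this side by side (an illustrative sketch with made-up buffer names, not code from this PR):

```python
import torch

def clone_if(t, flag):
    return t.clone() if flag else t

exp_avg, grad = torch.zeros(3), torch.ones(3)
beta, differentiable = 0.9, True

# Option A: one code path, with a conditional copy up front.
buf = clone_if(exp_avg, differentiable)
buf.mul_(beta).add_(grad, alpha=1 - beta)

# Option B: explicit branches, duplicating the update in each path.
if differentiable:
    buf = exp_avg.clone().mul_(beta).add_(grad, alpha=1 - beta)
else:
    buf = exp_avg.mul_(beta).add_(grad, alpha=1 - beta)
```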
Force-pushed from 1e4ec86 to 8484ad3 (Compare)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: the following mandatory check(s) failed. Dig deeper by viewing the failures on hud. Details for Dev Infra team: raised by workflow job.
Force-pushed from 8484ad3 to b1f8705 (Compare)
Well, I meant that instead of ...
Then instead of ..., it could always be written as: ... Also, are there functorch higher-order function wrappers that can act somehow like `torch.add(tensor, eps, inplace=not differentiable)`, to choose between `torch.add` and `torch.add_`?
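One way to read that suggestion (a hypothetical wrapper, not an existing torch or functorch API):

```python
import torch

def add_maybe_inplace(t, other, *, inplace):
    # A single call site picks the in-place or out-of-place variant,
    # instead of duplicating both branches throughout the optimizer code.
    return t.add_(other) if inplace else torch.add(t, other)

x = torch.ones(3)
y = add_maybe_inplace(x, 1e-8, inplace=False)  # new tensor, x unchanged
z = add_maybe_inplace(x, 1e-8, inplace=True)   # mutates and returns x
```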
Yes, we go twice over the memory, once to copy it and once to add.
Maybe? That would be a lot of work to add that to all APIs :p
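To make the memory point concrete: the copy-then-add formulation traverses the buffer twice, whereas a single out-of-place add produces the result in one pass (a small illustration, not PR code):

```python
import torch

x = torch.randn(1000)
eps = 1e-8

# Two passes over the memory: one to copy, one to add in place.
y = x.clone()
y.add_(eps)

# One pass: the out-of-place add writes the new tensor directly.
z = torch.add(x, eps)
```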
Actually, maybe not so many; not that many ops support proper in-place...
Also, functorch really doesn't like in-place in general ;) So I don't think they will be happy with adding an inplace kwarg haha
Well, maybe not in functorch :) but in core
Force-pushed from 0148172 to 8f160b2 (Compare)
You can skip the dynamo errors.
Force-pushed from 8f160b2 to 3ff8cd2 (Compare)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Blocked by #86096