dtype option for softmax #11719
Conversation
ssnl left a comment
Didn't review the kernels. But how about also adding the option to cross entropy loss? :)
cross_entropy calls softmax (https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L1645), so it would only require a couple-line Python change.

Yep, but you are already changing …

@ssnl, yes, I can do that, but I'm not sure what the preferred fix should be. FWIW, some operators already have overloads that are not supported from Python, so log_softmax erroring out with a similar message is not necessarily a big problem (?).

cc @apaszke @jamesr66a on the JIT test
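As a rough illustration of the "couple-line Python change" discussed above, here is a minimal sketch, assuming the dtype keyword this PR adds to log_softmax; the helper name cross_entropy_fp32 is made up and this is not the actual torch.nn.functional.cross_entropy implementation.

```python
import torch
import torch.nn.functional as F

def cross_entropy_fp32(input, target, dim=1):
    # Compute log_softmax in fp32 even when `input` is half precision,
    # then feed the fp32 log-probabilities to the NLL loss.
    log_probs = F.log_softmax(input, dim=dim, dtype=torch.float32)
    return F.nll_loss(log_probs, target)
```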
aten/src/ATen/native/SoftMax.cpp (Outdated)
@ngimel Looking at this more closely, I would advise updating the error message here.
In the jit tests in the …
test/test_jit.py (Outdated)
Force-pushed from 3e039a6 to 94c6400 (compare)
apaszke left a comment
I'm not super happy with the upconvert flag. It doesn't really specify the destination type. Should it be float? Should it be double? The context is probably dependent on the device, and this seems to overfit the CUDA context. Can't we apply a simple modification to our kernels, or simply have a _log_softmax_half_to_float implemented only for CUDA, and dispatch to _log_softmax(...).to(dtype) otherwise?
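For concreteness, a hedged Python-level sketch of the dispatch apaszke describes (a dedicated half-to-float path on CUDA, cast-then-compute everywhere else); the function log_softmax_with_dtype is hypothetical, and the real implementation lives in ATen, not Python.

```python
import torch

def log_softmax_with_dtype(input, dim, dtype=None):
    if dtype is None or dtype == input.dtype:
        return torch.log_softmax(input, dim)
    if input.is_cuda and input.dtype == torch.float16 and dtype == torch.float32:
        # Stand-in for a fused _log_softmax_half_to_float CUDA kernel that
        # reads half inputs and accumulates/writes the result in float.
        return torch.log_softmax(input.float(), dim)
    # Generic fallback: convert first, then run the ordinary kernel.
    return torch.log_softmax(input.to(dtype), dim)
```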
torch/_torch_docs.py (Outdated)
tools/autograd/derivatives.yaml (Outdated)
test/test_nn.py (Outdated)
That's precisely the problem. It's a very specific flag, with a very specific meaning, which is not at all implied by its name/function name/function signature. I don't understand why the dispatch would be a problem. Can't you just declare the derivatives for the top-level native function …
If derivatives are declared for …
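For intuition about the derivative question, here is a hedged autograd.Function sketch (not the actual derivatives.yaml entry) of what the backward of an upcasting log_softmax has to do: compute the gradient in fp32 and cast it back to the dtype of the original input.

```python
import torch

class LogSoftmaxFP32(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim):
        # Upcast to fp32, compute log_softmax, remember the original dtype.
        out = torch.log_softmax(input.float(), dim)
        ctx.dim = dim
        ctx.input_dtype = input.dtype
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        # d/dx log_softmax: grad - softmax(x) * sum(grad) along `dim`.
        grad = grad_output - out.exp() * grad_output.sum(ctx.dim, keepdim=True)
        # Cast the gradient back to the input's original dtype.
        return grad.to(ctx.input_dtype), None
```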
Force-pushed from 56a9546 to fe802a7 (compare)
facebook-github-bot left a comment

apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot left a comment

zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Test failures seem unrelated.
Summary:
Add a dtype argument to the softmax/log_softmax functions. Computing softmax in fp32 precision is necessary for mixed-precision training, and converting the output of the previous layer into fp32 and then reading it as fp32 in softmax is expensive, memory- and perf-wise; this PR allows one to avoid that. For most input data/dtype combinations, the input data is converted to dtype and then softmax is computed. If the input data is half type and dtype is fp32, kernels with the corresponding template arguments are called.

Pull Request resolved: pytorch/pytorch#11719
Reviewed By: ezyang
Differential Revision: D10175514
Pulled By: zou3519
fbshipit-source-id: 06d285af91a0b659932236d41ad63b787eeed243
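For reference, a small usage sketch of the added argument (the tensor shape here is arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 1000).half()  # e.g. the half-precision output of the previous layer
probs = F.softmax(x, dim=-1, dtype=torch.float32)  # softmax computed and returned in fp32
print(probs.dtype)  # torch.float32
```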