Conversation

@ngimel
Collaborator

@ngimel ngimel commented Sep 14, 2018

Add dtype argument to softmax/log_softmax functions.
Computing softmax in fp32 precision is necessary for mixed precision training, but converting the output of the previous layer to fp32 and then reading it back as fp32 in softmax is expensive in both memory and performance; this PR makes it possible to avoid that.
For most input data/dtype combinations, the input is converted to dtype and softmax is then computed in that type. If the input is half and dtype is fp32, kernels with the corresponding template arguments are called instead.
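
A minimal usage sketch of the proposed argument (assuming the keyword lands on the functional interface as described here; the exact signature is whatever this PR settles on):

import torch
import torch.nn.functional as F

# fp16 activations coming out of the previous layer (requires a CUDA device)
x = torch.randn(8, 1000, dtype=torch.half, device='cuda')

# Ask softmax/log_softmax to do their math in fp32 without first materializing
# an fp32 copy of the input; the result comes back as fp32.
probs = F.softmax(x, dim=-1, dtype=torch.float32)
logp = F.log_softmax(x, dim=-1, dtype=torch.float32)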

Collaborator

@ssnl ssnl left a comment


Didn't review the kernels. But how about also adding the option to cross entropy loss? :)

@ngimel
Collaborator Author

ngimel commented Sep 14, 2018

cross_entropy calls log_softmax (https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L1645), so it would only require a couple-line Python change.
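
For illustration, a hedged sketch of what that couple-line change could look like (the wrapper below is simplified and hypothetical; the real cross_entropy in functional.py takes more arguments):

from torch.nn.functional import log_softmax, nll_loss

def cross_entropy_with_dtype(input, target, dtype=None):
    # Hypothetical wrapper: forward the new dtype keyword down to log_softmax
    # so the softmax math runs in the requested precision (e.g. torch.float32).
    return nll_loss(log_softmax(input, dim=1, dtype=dtype), target)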

@ssnl
Collaborator

ssnl commented Sep 14, 2018

Yep, but you are already changing log_softmax, right?

@ngimel
Collaborator Author

ngimel commented Sep 14, 2018

@ssnl, yes, I can do that.
The test failure is legit:

 ======================================================================
22:32:33 FAIL: test_passing_one_positional_but_not_the_second (__main__.TestCustomOperators)
22:32:33 ----------------------------------------------------------------------
22:32:33 RuntimeError: Found 2 overloads for operator aten::log_softmax! Overloads are not supported from Python.
22:32:33 
22:32:33 During handling of the above exception, another exception occurred:
22:32:33 
22:32:33 Traceback (most recent call last):
22:32:33   File "test_jit.py", line 7796, in test_passing_one_positional_but_not_the_second
22:32:33     torch.ops.aten.log_softmax(torch.ones(5))
22:32:33 AssertionError: "aten::log_softmax\(\) is missing value for argument 'dim'." does not match "Found 2 overloads for operator aten::log_softmax! Overloads are not supported from Python."

but I'm not sure what the preferred fix should be. FWIW, some operators already have overloads that are not supported from Python, e.g.:

In [3]: torch.ops.aten.sum(torch.ones(5))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-acd80e61fcec> in <module>()
----> 1 torch.ops.aten.sum(torch.ones(5))

/workspace/ngimel/pytorch_upstream/torch/_ops.py in __getattr__(self, op_name)
     56         # for overloads and raise an exception if there are more than one.
     57         qualified_op_name = '{}::{}'.format(self.name, op_name)
---> 58         op = torch._C._jit_get_operation(qualified_op_name)
     59         # let the script frontend know that op is identical to the builtin op
     60         # with qualified_op_name

RuntimeError: Found 5 overloads for operator aten::sum! Overloads are not supported from Python.

so log_softmax erroring out with a similar message is not necessarily a big problem(?).

@ezyang
Contributor

ezyang commented Sep 18, 2018

cc @apaszke @jamesr66a on JIT test


@ezyang
Contributor

ezyang commented Sep 18, 2018

@ngimel Looking at this more closely I would advise updating the error message here.

@ngimel
Collaborator Author

ngimel commented Sep 18, 2018

@ngimel Looking at this more closely I would advise updating the error message here.

In the JIT tests or in the softmax_cpu assert? softmax_cpu should never be called with upconvert = True, and upconvert is not user-exposed; if that happened, there would be something wrong in the core that the user can't fix, hence AT_ASSERTM and not AT_CHECK.

@ngimel
Collaborator Author

ngimel commented Sep 20, 2018

Anything I can do to help move this forward? @apaszke @ezyang

Contributor

@apaszke apaszke left a comment


I'm not super happy with the upconvert flag. It doesn't really specify the destination type. Should it be float? Should it be double? The context is probably dependent on the device, and this seems to overfit the CUDA context. Can't we apply a simple modification to our kernels, or simply have a _log_softmax_half_to_float implemented only for CUDA, and dispatch to _log_softmax(...).to(dtype) otherwise?
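
A Python-level sketch of the dispatch being suggested; the helper names below (_log_softmax_half_to_float, _log_softmax_reference) are placeholders for illustration, not actual ATen entry points:

import torch

def _log_softmax_reference(x, dim):
    # Numerically stable log-softmax in whatever dtype x already has.
    shifted = x - x.max(dim=dim, keepdim=True).values
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()

def log_softmax_with_dtype(x, dim, dtype=None):
    if dtype == torch.float32 and x.dtype == torch.half and x.is_cuda:
        # Fast path: a dedicated half-in/float-out CUDA kernel, e.g. a
        # hypothetical aten::_log_softmax_half_to_float. Stand-in below:
        return _log_softmax_reference(x.float(), dim)
    # Generic path: convert first, then run the ordinary kernel.
    x = x if dtype is None else x.to(dtype)
    return _log_softmax_reference(x, dim)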


@ngimel
Collaborator Author

ngimel commented Sep 24, 2018

I'm not super happy with the upconvert flag. It doesn't really specify the destination type. Should it be float? Should it be double? The context is probably dependent on the device, and this seems to overfit the CUDA context. Can't we apply a simple modification to our kernels, or simply have a _log_softmax_half_to_float implemented only for CUDA, and dispatch to _log_softmax(...).to(dtype) otherwise?

upconvert is true only for CUDA half inputs with an fp32 dtype argument. I could dispatch to _log_softmax_half_to_float in this case, but it would require its own entries for forward and backward in native_functions and in derivatives, and overall I don't think it would be any prettier.
Modifying the kernels to support more input type/dtype combinations with a fast path can be done (in fact, for the CUDA kernels the output type can be anything; it's a separate template parameter), but then the dispatch would have to get really tricky (right now the dispatch defines scalar_t, from which I can derive acc_type, but any other combination of input/output types would require changes to the types defined in the dispatch, and instantiating a cross-product of kernels with different input/output types, which no one wants).

@apaszke
Contributor

apaszke commented Sep 25, 2018

upconvert is true only for CUDA half inputs with an fp32 dtype argument

That's precisely the problem. It's a very specific flag, with a very specific meaning, which is not at all implied by its name/function name/function signature.

I don't understand why the dispatch would be a problem. Can't you just declare the derivatives for the top-level native function log_softmax, and have it take full responsibility for providing the derivative no matter which implementation it chooses?

@ngimel
Collaborator Author

ngimel commented Sep 25, 2018

If derivatives are declared for log_softmax, then backward will have to take care of the type conversion that is currently delegated to autograd. That's not the end of the world, but it will make backward more error-prone, especially if implicit conversions for other types are added at some point.
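
To make concrete what "type conversion delegated to autograd" means here, a small hedged example (the exact casting behavior is an assumption based on how the feature is described in this thread):

import torch
import torch.nn.functional as F

x = torch.randn(4, 10, dtype=torch.half, device='cuda', requires_grad=True)
out = F.log_softmax(x, dim=-1, dtype=torch.float32)  # fp32 output from fp16 input
out.sum().backward()
# out is fp32, while x.grad is expected to come back as fp16: the cast back to
# the input dtype is handled outside the kernel's own backward formula.
print(out.dtype, x.grad.dtype)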

Contributor

@facebook-github-bot facebook-github-bot left a comment


apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@facebook-github-bot facebook-github-bot left a comment


zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator Author

ngimel commented Oct 12, 2018

Test failures seem unrelated.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 14, 2018
Summary:
Add dtype argument to softmax/log_softmax functions.
Computing softmax in fp32 precision is necessary for mixed precision training, but converting the output of the previous layer to fp32 and then reading it back as fp32 in softmax is expensive in both memory and performance; this PR makes it possible to avoid that.
For most input data/dtype combinations, the input is converted to dtype and softmax is then computed in that type. If the input is half and dtype is fp32, kernels with the corresponding template arguments are called instead.
Pull Request resolved: pytorch/pytorch#11719

Reviewed By: ezyang

Differential Revision: D10175514

Pulled By: zou3519

fbshipit-source-id: 06d285af91a0b659932236d41ad63b787eeed243
@ngimel ngimel deleted the mixed_softmax branch January 16, 2019 19:51