[Quant][FX] Add Linear-(BN)-LeakyReLU fusion for ONEDNN backend #76424
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit afbd3fd595.
Force-pushed 9c87b13 to 5ec2390
Related discussions on fusion/inplace for the training-time, non-quantized regime: pytorch/vision#4851 (comment), #26288, #23756, pytorch/vision#4851. Given that BN + LeakyReLU + LeakyDropout can be made invertible, Conv + BN (or another normalization) + LeakyReLU + LeakyDropout can be fused, or at least turned into invertible modules that use no extra memory for storing activations (provided a special invertible Sequential container is implemented that allows hooks to be explicitly prohibited within the fused invertible container).
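The invertibility claim can be checked directly: in eval mode BatchNorm is an affine map and LeakyReLU is strictly monotonic for a nonzero negative slope, so their composition can be inverted and inputs recomputed instead of stored. A minimal sketch (not part of this PR; 1-D case, eval-mode running statistics assumed):

```python
import torch

def bn_leaky_relu(x, bn, negative_slope=0.01):
    # eval-mode forward: affine BatchNorm followed by LeakyReLU
    return torch.nn.functional.leaky_relu(bn(x), negative_slope)

def inverse_bn_leaky_relu(z, bn, negative_slope=0.01):
    # invert LeakyReLU: z >= 0 came from y = z, z < 0 came from y = z / slope
    y = torch.where(z >= 0, z, z / negative_slope)
    # invert the affine BN map y = (x - mean) / sqrt(var + eps) * weight + bias
    std = torch.sqrt(bn.running_var + bn.eps)
    return (y - bn.bias) / bn.weight * std + bn.running_mean

bn = torch.nn.BatchNorm1d(4).eval()
x = torch.randn(8, 4)
z = bn_leaky_relu(x, bn)
assert torch.allclose(inverse_bn_leaky_relu(z, bn), x, atol=1e-5)
```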
Force-pushed fdccc19 to ae02ba6
Force-pushed 672dcbc to 09eda28
Added
Force-pushed 8a7a399 to 1ad8e53
…ntization.py and simplify implementation of unit tests for linear-leaky_relu
…_relu.py to torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
Force-pushed 1ad8e53 to 7b89273
Hi @jerryzh168, please review. Thanks.
# DTYPE CONFIGS
# ===================
# Based on fbgemm's configs
Do we have to copy all the fbgemm configs here? onednn doesn't support all of them, e.g., fp16 and quint4x2. How about only adding the ones onednn supports?
OK I will fix these later.
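A minimal sketch of the suggested fix, assuming the DTypeConfig API from torch.ao.quantization.backend_config (names may differ across PyTorch versions): register only the int8 configs onednn supports instead of copying fbgemm's full list.

```python
import torch
from torch.ao.quantization.backend_config import DTypeConfig

# int8 activation / int8 weight is what onednn supports;
# the fp16 and quint4x2 configs from fbgemm are intentionally omitted.
onednn_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

linear_dtype_configs = [onednn_weighted_op_int8_dtype_config]
```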
    .set_dtype_configs(linear_dtype_configs))
'''
# (2) Linear + tanh
Add it in a separate PR?
template <PostOps post_op>
at::Tensor PackedLinearWeightsOnednn::apply_impl(
    at::Tensor input,
    double negative_slope,
Suggest not adding a post-op-specific parameter to the args of a general function. Consider doing something like this:
https://github.com/pytorch/pytorch/pull/86583/files#diff-c24c006375f92e12dd5b062f143b62d41f347477b0efb7c25e3824ac41f0ad7eR184
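For illustration only, a Python sketch of the suggested interface shape (the real code is C++ in the linked PR): pass a generic post-op name plus argument list instead of hard-coding negative_slope into the general apply_impl signature.

```python
from typing import List, Optional
import torch

def apply_impl_sketch(out: torch.Tensor,
                      post_op: str = "none",
                      post_op_args: Optional[List[float]] = None) -> torch.Tensor:
    # `out` stands in for the result of the quantized linear kernel; the
    # post-op is applied from a generic (name, args) pair, so the signature
    # does not need to change when new post-ops are added.
    args = post_op_args or []
    if post_op == "leaky_relu":
        return torch.nn.functional.leaky_relu(out, args[0] if args else 0.01)
    if post_op == "relu":
        return torch.relu(out)
    return out
```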
@given(batch_size=st.integers(1, 4),
       input_channels=st.integers(16, 32),
       output_channels=st.integers(4, 8),
       use_bias=st.booleans(),
       use_multi_dim_input=st.booleans(),
       use_channelwise=st.booleans())
We can use this locally to make sure the implementation is good, but ideally we don't want to commit it to master. Hypothesis tests have caused many issues in the past since they are not deterministic and always try to find corner cases where the test fails, and they typically can, since quantized kernels involve rounding, clamping, etc., which makes them not very numerically stable.
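A sketch of the deterministic alternative (hypothetical test, not from this PR): enumerate a small fixed set of shapes with itertools.product instead of letting hypothesis search for failing corner cases.

```python
import itertools
import torch

def test_linear_leaky_relu_fixed_cases():
    for batch_size, in_c, out_c, use_bias in itertools.product(
            [1, 4], [16, 32], [4, 8], [True, False]):
        x = torch.randn(batch_size, in_c)
        linear = torch.nn.Linear(in_c, out_c, bias=use_bias)
        ref = torch.nn.functional.leaky_relu(linear(x), 0.01)
        # the real test would quantize `linear` under the onednn engine and
        # compare the quantized output to `ref` within a fixed tolerance
        assert ref.shape == (batch_size, out_c)
```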
dtype = torch.quint8
negative_slope = 0.01
# for onednn backend only
with override_quantized_engine('onednn'):
Are we still planning to use onednn, or do we want to switch to x86 now?
Thanks for pointing that out, @jerryzh168. Ideally, we'd enable these new fusions on x86 directly. But the problem now is that perf with the onednn kernels is not always better than fbgemm, even with post-op fusion. Therefore, our plan is to enable these new fusions on the onednn backend first (which benefits users in some cases, but not all) while the perf of the onednn library is being improved. These new fusions can then be moved to x86 once onednn kernels plus fusion consistently bring better perf. Does this plan make sense to you?
Yeah, sounds good to me. Thanks for the clarification.
@@ -0,0 +1,253 @@
import torch
Do we want to have an onednn backend, or just an x86 backend?
This PR is really large; I feel it might make sense to break it into smaller PRs.
Hi @jerryzh168, thanks for your comments. As explained by Jiong and according to our previous discussion on Slack, we will continue improving the onednn backend. I will close this PR later and split it into smaller ones.
Add Linear-(BN)-LeakyReLU fusion for FX quantization. It is enabled for the ONEDNN backend only, through backend_config. For FBGEMM and QNNPACK, FX won't fuse the pattern. If users set fbgemm/qnnpack as the backend while using onednn's backend_config to do quantization, an error is thrown when they run the quantized model. Tests were added to make sure linear - bn - leaky_relu is not fused by default for fbgemm or qnnpack.
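A hypothetical usage sketch of the fusion described above (helper names such as get_onednn_backend_config and get_default_qconfig_mapping are assumed from later PyTorch releases and may differ; a build with the onednn quantization engine is required):

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_onednn_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)
        self.bn = torch.nn.BatchNorm1d(8)
        self.leaky_relu = torch.nn.LeakyReLU(0.01)

    def forward(self, x):
        return self.leaky_relu(self.bn(self.linear(x)))

torch.backends.quantized.engine = "onednn"  # select onednn kernels
model = M().eval()
example_inputs = (torch.randn(4, 16),)

# prepare/convert with onednn's backend_config so linear-bn-leaky_relu fuses
prepared = prepare_fx(
    model,
    get_default_qconfig_mapping("onednn"),
    example_inputs,
    backend_config=get_onednn_backend_config(),
)
prepared(*example_inputs)  # calibration
quantized = convert_fx(prepared, backend_config=get_onednn_backend_config())
```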