
Conversation

@Xia-Weiwen (Collaborator) commented Apr 27, 2022

Add Linear-(BN)-LeakyReLU fusion for FX quantization. It is enabled for the ONEDNN backend only, via backend_config.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_onednn_backend_config
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

# model_fp32: a float model containing the linear-(bn)-leaky_relu pattern
# x: example inputs for the model, e.g. a tuple of example input tensors
qengine = 'onednn'
torch.backends.quantized.engine = qengine
qconfig_mapping = get_default_qconfig_mapping(qengine)
prepared_model = prepare_fx(model_fp32, qconfig_mapping,
                            example_inputs=x, backend_config=get_onednn_backend_config())
quantized_model = convert_fx(prepared_model, backend_config=get_onednn_backend_config())

For FBGEMM and QNNPACK, FX won't fuse the pattern. If users set fbgemm/qnnpack as the quantized engine while using onednn's backend_config for quantization, an error is thrown when they run the quantized model.

Tests were added to make sure linear - bn - leaky_relu is not fused by default for fbgemm or qnnpack.
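
Below is a rough sketch (not the actual test added in this PR) of what such a check could look like: quantize a small linear + leaky_relu model with the default (fbgemm) backend config and assert that no fused quantized::linear_leaky_relu call appears in the converted graph. The module and shapes are illustrative only.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

class LinearLeakyReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.leaky_relu = torch.nn.LeakyReLU(0.01)

    def forward(self, x):
        return self.leaky_relu(self.linear(x))

torch.backends.quantized.engine = 'fbgemm'
example_inputs = (torch.randn(2, 8),)
m = prepare_fx(LinearLeakyReLU().eval(),
               get_default_qconfig_mapping('fbgemm'),
               example_inputs=example_inputs)
m(*example_inputs)  # calibrate
m = convert_fx(m)
# With the default backend config, linear and leaky_relu stay separate ops.
assert all('linear_leaky_relu' not in str(n.target) for n in m.graph.nodes)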

@facebook-github-bot (Contributor) commented Apr 27, 2022

✅ No Failures (0 Pending)

As of commit afbd3fd595 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@Xia-Weiwen force-pushed the linear_bn_leaky_relu_fusion branch 2 times, most recently from 9c87b13 to 5ec2390 on April 27, 2022 03:10

@vadimkantorov (Contributor) commented Apr 27, 2022

Related discussions on fusion/in-place handling for the training-time, non-quantized regime: pytorch/vision#4851 (comment), #26288, #23756, pytorch/vision#4851

Given that BN + LeakyReLU + LeakyDropout can be made invertible, Conv + BN (or another normalization) + LeakyReLU + LeakyDropout could be fused, or at least turned into invertible modules that need no extra memory for storing activations (provided a special invertible Sequential container is implemented that allows hooks to be explicitly prohibited within the fused invertible container).
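
To illustrate the invertibility point above, here is a minimal sketch (not part of this PR, plain PyTorch assumed): LeakyReLU with a non-zero negative slope is strictly monotonic, so its input can be recomputed from its output instead of being stored for the backward pass.

import torch

def leaky_relu(x, negative_slope=0.01):
    return torch.where(x >= 0, x, x * negative_slope)

def leaky_relu_inverse(y, negative_slope=0.01):
    # Recover the input from the output; no activation needs to be saved.
    return torch.where(y >= 0, y, y / negative_slope)

x = torch.randn(16)
assert torch.allclose(leaky_relu_inverse(leaky_relu(x)), x)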

@Xia-Weiwen force-pushed the linear_bn_leaky_relu_fusion branch 23 times, most recently from fdccc19 to ae02ba6 on May 26, 2022 09:01

@linux-foundation-easycla (bot) commented Oct 4, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: Xia-Weiwen / name: Xia Weiwen (75da6cccd543ce3b6b86491386322579ff8d234d, 7d9b8f5e4f109b0599b1e5ea7e263ba26d715c51, d6c4f35a10c9d371dcc33aebc5c028aca6c35e71, 1f9e2f34c6498754c5e542e47768b7c44b520778, 4a151a474e75593892b5c4ade550c3e305617875, 024d698315909ab36f5835e4397deaac3ec03372, 60404c54e2cccaee07c19ef5e95555a8b11a6ad4, 35b9c26083a583446da216f4a474ca649b33829b)

@Xia-Weiwen force-pushed the linear_bn_leaky_relu_fusion branch 2 times, most recently from 672dcbc to 09eda28 on October 8, 2022 09:42

@Xia-Weiwen (Collaborator, Author) commented

Hi @kimishpatel, @vkuzo. Thanks for the comments. Since this PR had not been updated for a while, the description was out of date. I have synced with @jerryzh168, and the fusion is now for the onednn backend only. FX won't fuse linear + leaky_relu for fbgemm or qnnpack. If users manually call the fused op while using fbgemm/qnnpack, an error is thrown. How does that sound to you?

> Sounds reasonable, can we add a test that linear + leaky_relu is not fused by default for fbgemm and qnnpack?

Added

@Xia-Weiwen marked this pull request as ready for review on October 8, 2022 10:16
@Xia-Weiwen force-pushed the linear_bn_leaky_relu_fusion branch from 8a7a399 to 1ad8e53 on October 8, 2022 12:25
@Xia-Weiwen force-pushed the linear_bn_leaky_relu_fusion branch from 1ad8e53 to 7b89273 on October 20, 2022 10:27

@Xia-Weiwen (Collaborator, Author) commented

Hi @jerryzh168, please review. Thanks.

@Xia-Weiwen requested a review from jgong5 on October 21, 2022 07:02
# | DTYPE CONFIGS |
# ===================

# Based on fbgemm's configs

Collaborator:

Do we have to copy all the fbgemm configs here? onednn doesn't support all of them, e.g., it doesn't support fp16 and quint4x2. How about adding only those that onednn supports?

Collaborator (Author):

OK I will fix these later.
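
For illustration only, a restricted dtype config for onednn might keep just the standard int8 weighted-op configuration and drop the fp16 and quint4x2 entries copied from fbgemm; the exact configs landed in the PR may differ.

import torch
from torch.ao.quantization.backend_config import DTypeConfig

onednn_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)
linear_dtype_configs = [onednn_weighted_op_int8_dtype_config]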

.set_dtype_configs(linear_dtype_configs))

'''
# (2) Linear + tanh

Collaborator:

Add it in a separate PR?

template <PostOps post_op>
at::Tensor PackedLinearWeightsOnednn::apply_impl(
    at::Tensor input,
    double negative_slope,

Collaborator:

Suggest not adding the post-op-specific parameter to the arguments of a general function. Consider doing something like this:
https://github.com/pytorch/pytorch/pull/86583/files#diff-c24c006375f92e12dd5b062f143b62d41f347477b0efb7c25e3824ac41f0ad7eR184

Comment on lines +3817 to +3822
@given(batch_size=st.integers(1, 4),
       input_channels=st.integers(16, 32),
       output_channels=st.integers(4, 8),
       use_bias=st.booleans(),
       use_multi_dim_input=st.booleans(),
       use_channelwise=st.booleans())

Contributor:

We can use this locally to make sure the implementation is good, but ideally we don't want to commit them to master. Hypothesis tests have caused many issues in the past since they are not deterministic and always try to find corner cases that make the test fail, and typically they can find some, since quantized kernels have rounding, clamping, etc., which makes them not very numerically stable.
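
A deterministic alternative to the hypothesis decorator above could iterate over a fixed set of parameter combinations; the helper name _test_qlinear_leaky_relu_impl below is a placeholder, not the actual test code.

import itertools

def test_qlinear_leaky_relu(self):
    options = itertools.product(
        [1, 4],          # batch_size
        [16, 32],        # input_channels
        [4, 8],          # output_channels
        [True, False],   # use_bias
        [True, False],   # use_multi_dim_input
        [True, False],   # use_channelwise
    )
    for (batch_size, input_channels, output_channels,
         use_bias, use_multi_dim_input, use_channelwise) in options:
        self._test_qlinear_leaky_relu_impl(
            batch_size, input_channels, output_channels,
            use_bias, use_multi_dim_input, use_channelwise)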

dtype = torch.quint8
negative_slope = 0.01
# for onednn backend only
with override_quantized_engine('onednn'):

Contributor:

Are we still planning to use onednn, or do we want to switch to x86 now?

Collaborator:

Thanks for pointing this out, @jerryzh168. Ideally, we'd enable these new fusions on x86 directly. But the problem right now is that the performance of the onednn kernel is not always better than fbgemm, even with post-op fusion. Therefore, our plan is to enable these new fusions on the onednn backend first (which can benefit users in some cases, but not all) while the performance of the onednn library is being improved. These new fusions can then be moved to x86 once the onednn kernel plus fusion always brings better performance. Does this plan make sense to you?

Contributor:

Yeah, sounds good to me. Thanks for the clarification.

@@ -0,0 +1,253 @@
import torch

Contributor:

Do we want to have an onednn backend, or just the x86 backend?

@jerryzh168 (Contributor) commented

This PR is really large; I feel it might make sense to break it into:
1) linear-bn-leakyrelu fusion op/kernel implementation and tests
2) defining the backend config
3) test support in the quantization flow

@Xia-Weiwen (Collaborator, Author) commented

> This PR is really large, I feel it might make sense to break this into 1) linear-bn-leakyrelu fusion op/kernel implementation and tests, 2) defining backend config, 3) test support in the quantization flow

Hi @jerryzh168, thanks for your comments. As explained by Jiong and according to our previous discussion on Slack, we will continue improving the onednn backend. I will close this PR later and split it into smaller ones.

@Xia-Weiwen closed this on Nov 4, 2022
@Xia-Weiwen deleted the linear_bn_leaky_relu_fusion branch on November 13, 2024 06:07

Labels

cla signed, intel, module: fx, module: nn, open source, release notes: quantization, Stale, triaged
