[Quant] Remove all the dequant nodes when the ref module has multi input args #90157
Conversation
…input [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90157
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 0ba4c89. This comment was automatically generated by Dr. CI and updates every 15 minutes.
…as multi input args"

**Summary**: When converting a ref module into a quant module, the `_lower_static_weighted_ref_module` pass assumes the `ref_node` only has 1 input node, and only removes the first `dequant` node. However, when we enable the `conv add` fusion, there will be an extra input node from the `add` node besides the original input node from `conv`. Similar to what is done in the `_lower_quantized_binary_op` pass (https://github.com/pytorch/pytorch/blob/41c3b41b92f5019f8d5e2f2846a06b87db01ca4e/torch/ao/quantization/fx/_lower_to_native_backend.py#L766-L775), we should remove all the `dequant` nodes in the `_lower_static_weighted_ref_module` pass.

**Test Plan**: This is a bug fix rather than a new feature. When we enable the `conv add` fusion PR later, we will add test cases accordingly.

cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
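In other words, the originally proposed fix loops over all of `ref_node`'s inputs instead of unhooking only the first one. A minimal sketch of that idea; the helper and function names here are illustrative, not the actual code in `_lower_to_native_backend.py`:

```python
import torch
from torch.fx import Node

def _is_dequantize_node(node: Node) -> bool:
    # Simplified check for illustration: matches Tensor.dequantize() calls
    # and torch.dequantize(...) calls produced by the reference pattern.
    return (node.op == "call_method" and node.target == "dequantize") or (
        node.op == "call_function" and node.target == torch.dequantize
    )

def _remove_all_dequant_inputs(ref_node: Node) -> None:
    # A fused reference module (e.g. conv + add) can have more than one input,
    # and each input may be fed by its own dequantize node.
    for arg in ref_node.args:
        if isinstance(arg, Node) and _is_dequantize_node(arg):
            # Rewire ref_node to consume the quantized tensor directly,
            # bypassing the dequantize node.
            ref_node.replace_input_with(arg, arg.args[0])
```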
@jerryzh168 I think this PR is ready for review, could you help to take a look?
…input ghstack-source-id: 63fcac0 Pull Request resolved: pytorch#90157
Hi @jerryzh168, could you help to take a look at this fix?
```python
# Excerpt from _lower_static_weighted_ref_module that the review comment below refers to:
dq_node = ref_node.args[0]
assert(isinstance(dq_node, Node))
ref_node.replace_input_with(dq_node, dq_node.args[0])
for arg in ref_node.args:
```
I feel this might be a bit hacky; generally we want to identify a specific pattern and just lower that pattern. This function (`_lower_static_weighted_ref_module`) is assuming that we have the "dq -> ref_fp32_module -> q" pattern, I think. Could you
1) add some checks to this function to make sure this is the case, e.g. check `len(ref_node.args) == 1` or something, and
2) add another lowering function for conv -> add that uses this code?
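For illustration, option 1) could look roughly like this; a sketch assuming the guard sits at the top of the pass, not the merged code:

```python
from torch.fx import Node

def _check_single_dq_input(ref_node: Node) -> Node:
    # Guard for the "dq -> ref_fp32_module -> q" pattern this pass assumes:
    # exactly one input, and that input must be a node (the dequantize node).
    assert len(ref_node.args) == 1, (
        f"expected a single input for {ref_node}, got {len(ref_node.args)}"
    )
    dq_node = ref_node.args[0]
    assert isinstance(dq_node, Node)
    return dq_node
```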
Thanks for the suggestions, @jerryzh168. I followed up on these two steps:
- In this PR, I only add the check of `len(ref_node.args) == 1` for this pass (`_lower_static_weighted_ref_module`).
- I have added another lowering pass for `conv -> add`, named `_lower_static_weighted_ref_module_with_two_dq_inputs`, in [Quant][FX] Lower QConvAdd2d for onednn backend #91153 (a rough sketch of the idea follows below). Could you take a look at that lowering pass as we discussed here? Does it look good to you?
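A hypothetical sketch of what that two-dq-input lowering step does, based only on the description above and not taken from #91153:

```python
from torch.fx import Node

def _remove_two_dq_inputs(ref_node: Node) -> None:
    # The fused conv + add reference module takes two inputs: the conv input
    # and the extra tensor being added. Each is expected to come from its own
    # dequantize node, and both are bypassed here.
    assert len(ref_node.args) == 2
    for dq_node in ref_node.args:
        assert isinstance(dq_node, Node)
        ref_node.replace_input_with(dq_node, dq_node.args[0])
```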
Hi @jerryzh168, could you help to take a look at this fix again?
Sounds great, could you update the summary for this PR as well?
…input ghstack-source-id: 08a8267 Pull Request resolved: pytorch#90157
jerryzh168 left a comment:
Thanks, please update the summary as well.
Updated the summary for this PR.

@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.
…as multi input args"

**Summary**: When converting a ref module into a quant module, the `_lower_static_weighted_ref_module` pass assumes the `ref_node` only has 1 input node, and only removes the first `dequant` node. We add a check in this PR to ensure this is the case for the `_lower_static_weighted_ref_module` pass.

**Test Plan**: We only add a check in this PR; no new test case is added.

cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
Successfully rebased

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Summary:
When converting a ref module into a quant module, the `_lower_static_weighted_ref_module` pass assumes the `ref_node` only has 1 input node, and only removes the first `dequant` node. We add a check in this PR to ensure this is the case for the `_lower_static_weighted_ref_module` pass.

Test Plan:
We only add a check in this PR; no new test case is added.
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10