
Conversation

@leslie-fang-intel (Collaborator) commented on Dec 5, 2022

Stack from ghstack (oldest at bottom):

Summary:
When converting a ref module into a quant module, the `_lower_static_weighted_ref_module` pass assumes that `ref_node` has only one input node and therefore removes only the first `dequant` node. This PR adds a check to ensure that assumption holds in the `_lower_static_weighted_ref_module` pass.
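
For illustration only, here is a minimal sketch of the kind of single-input guard described above; the helper `_is_dequantize_node` and the surrounding structure are assumptions for this sketch, not the code actually added in this PR.

```python
from torch.fx import Node

def _is_dequantize_node(node: Node) -> bool:
    # Assumption for this sketch: in a converted reference model a dequant
    # node shows up as a "dequantize" call_method on a quantized tensor.
    return node.op == "call_method" and node.target == "dequantize"

def _check_single_dequant_input(ref_node: Node) -> None:
    # The static weighted-module lowering expects the pattern
    # dq -> ref_fp32_module -> q, i.e. exactly one positional input,
    # and that input must be a dequantize node.
    assert len(ref_node.args) == 1, (
        f"expected one input for {ref_node}, got {len(ref_node.args)}"
    )
    dq_node = ref_node.args[0]
    assert isinstance(dq_node, Node) and _is_dequantize_node(dq_node)
```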

Test Plan:
We only add a check in this PR; there is no newly added test case.

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot (bot) commented on Dec 5, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90157

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0ba4c89:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

leslie-fang-intel added a commit that referenced this pull request Dec 5, 2022
@github-actions bot added the "release notes: quantization" label on Dec 5, 2022
@leslie-fang-intel added the "ciflow/trunk" label on Dec 5, 2022
@leslie-fang-intel marked this pull request as draft on December 5, 2022 06:47
@leslie-fang-intel changed the title from "If the ref module has multi args, remove all the dequant node as the input" to "[Quant] Remove all the dequant nodes when the ref module has multi input args." on Dec 5, 2022
@leslie-fang-intel changed the title from "[Quant] Remove all the dequant nodes when the ref module has multi input args." to "[Quant] Remove all the dequant nodes when the ref module has multi input args" on Dec 5, 2022
@leslie-fang-intel added the "intel" and "oncall: quantization" labels on Dec 5, 2022
…as multi input args"


**Summary**:
When converting a ref module into a quant module, the `_lower_static_weighted_ref_module` pass assumes that `ref_node` only has one input node, and it only removes the first `dequant` node. However, when we enable the `conv add` fusion, there will be an extra input node coming from the `add` node besides the original input node from `conv`. Similar to what is done in the `_lower_quantized_binary_op` pass (https://github.com/pytorch/pytorch/blob/41c3b41b92f5019f8d5e2f2846a06b87db01ca4e/torch/ao/quantization/fx/_lower_to_native_backend.py#L766-L775), we should remove all the `dequant` nodes in the `_lower_static_weighted_ref_module` pass.
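
As a rough sketch of the approach described above (not the actual diff in this PR), rewiring every dequantized input could look like the loop below; detecting dequant nodes via a `"dequantize"` call_method is an assumption about the converted graph.

```python
from torch.fx import Node

def _remove_all_dequant_inputs(ref_node: Node) -> None:
    # Sketch only: for every positional input of the reference module node
    # that is produced by a dequantize call, rewire ref_node to consume the
    # quantized value that fed the dequantize node instead.
    for arg in ref_node.args:
        if not isinstance(arg, Node):
            continue
        if arg.op == "call_method" and arg.target == "dequantize":
            ref_node.replace_input_with(arg, arg.args[0])
```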

**Test Plan**:
This is a bug fix rather than a new feature. When we enable the `conv add` fusion in a later PR, we will add test cases accordingly.


cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
@leslie-fang-intel (Collaborator, Author)

@jerryzh168 I think this PR is ready for review, could you help take a look?

leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Dec 19, 2022
@leslie-fang-intel (Collaborator, Author)

Hi @jerryzh168, could you help take a look at this fix?

dq_node = ref_node.args[0]
assert(isinstance(dq_node, Node))
ref_node.replace_input_with(dq_node, dq_node.args[0])
for arg in ref_node.args:
@jerryzh168 (Contributor) commented on Dec 20, 2022

I feel this might be a bit hacky; generally we want to identify a specific pattern and lower just that pattern. This function (`_lower_static_weighted_ref_module`) assumes we have the "dq -> ref_fp32_module -> q" pattern, I think. Could you
1) add some checks to this function to make sure this is the case, e.g. check `len(ref_node.args) == 1` or something, and
2) add another lowering function for conv -> add that uses this code?
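
To make the pattern being discussed concrete, here is an illustrative check for the "dq -> ref_fp32_module -> q" shape. It is a sketch under the assumptions that dequant nodes are `"dequantize"` method calls and the requantize node is a `torch.quantize_per_tensor` call; it is not the actual implementation in `_lower_to_native_backend.py`.

```python
import torch
from torch.fx import Node

def _matches_dq_ref_q(ref_node: Node) -> bool:
    # One dequantize producer feeding the reference module...
    if len(ref_node.args) != 1:
        return False
    dq = ref_node.args[0]
    if not (isinstance(dq, Node) and dq.op == "call_method" and dq.target == "dequantize"):
        return False
    # ...and a single quantize_per_tensor consumer re-quantizing its output.
    users = list(ref_node.users)
    return (
        len(users) == 1
        and users[0].op == "call_function"
        and users[0].target is torch.quantize_per_tensor
    )
```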

@leslie-fang-intel (Collaborator, Author) commented on Dec 20, 2022

Thanks for the suggestions, @jerryzh168. I followed up with these two steps:

  1. In this PR, I only add the check `len(ref_node.args) == 1` for this pass (`_lower_static_weighted_ref_module`).
  2. I added another lowering pass for conv -> add, named `_lower_static_weighted_ref_module_with_two_dq_inputs`, in [Quant][FX] Lower QConvAdd2d for onednn backend #91153. Could you take a look at that lowering pass, as we talked about here? Does it look good to you?

@leslie-fang-intel (Collaborator, Author)

Hi @jerryzh168, could you take another look at this fix?

@jerryzh168 (Contributor)

Sounds great, could you update the summary for this PR as well?

leslie-fang-intel added a commit to leslie-fang-intel/pytorch that referenced this pull request Dec 20, 2022
@jerryzh168 (Contributor) left a comment

thanks, please update the summary as well

@leslie-fang-intel (Collaborator, Author)

> thanks, please update the summary as well

Updated the summary for this PR.

@leslie-fang-intel (Collaborator, Author)

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a rebase job. Check the current status here

…as multi input args"


**Summary**:
When converting a ref module into a quant module, the `_lower_static_weighted_ref_module` pass assumes that `ref_node` has only one input node and therefore removes only the first `dequant` node. This PR adds a check to ensure that assumption holds in the `_lower_static_weighted_ref_module` pass.

**Test Plan**:
We only add a check in this PR; there is no newly added test case.

cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
@pytorchmergebot (Collaborator)

Successfully rebased gh/leslie-fang-intel/3/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/90157)

pytorchmergebot pushed a commit that referenced this pull request Jan 5, 2023
@leslie-fang-intel (Collaborator, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@leslie-fang-intel (Collaborator, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/leslie-fang-intel/3/head branch June 8, 2023 17:52

Labels

ciflow/trunk, intel, Merged, oncall: quantization, open source, release notes: AO frontend, release notes: quantization

Projects

Status: Done

7 participants