explicitly remove `call_mod_node_to_replace` after inlining the submodule in `const_fold._inline_module` #166871

jazlyn5 wants to merge 1 commit into pytorch:main
Conversation
Dr. CI: ✅ No failures as of commit 12a0499 with merge base aa4a8c9.
Force-pushed c1fbf22 to ddd68dc.
Commit: explicitly remove `call_mod_node_to_replace` after inlining the submodule in `const_fold._inline_module` (pytorch#166871)

Summary:

pytorch#166609 updated the `is_impure` check to inspect the ops inside a subgraph when deciding whether a `call_module` node is pure. This change of behavior affects dead code elimination, commonly run as `gm.graph.eliminate_dead_code()`: dead code elimination will not erase a node with no users if that node is side-effectful or impure. With the above-mentioned PR, dead code elimination no longer eliminates unused subgraphs that contain side-effectful ops.

This affects `const_fold.split_const_subgraph`, which does the following:

1. Split the graph into two submodules, one containing all const ops and one containing the non-const ops.
2. Inline the submodule containing non-const ops back into the main graph.
3. Run dead code elimination to remove the now-unused non-const submodule.

With PR pytorch#166609, step 3 no longer erases the unused module. As an example, the exported graph

```
graph():
    %x : [num_users=2] = placeholder[target=x]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([5, 10], [0, 1]), kwargs = {device: cpu, pin_memory: False})
    %bernoulli : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%empty_permuted, 0.6), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul, 0.6), kwargs = {})
    return (div,)
```

After running const_fold, `empty_permuted` is const-folded, the rest of the ops are not, and the main graph looks like

```
graph():
    %x : [num_users=3] = placeholder[target=x]
    %_fx_const_folded_attrs : [num_users=2] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %bernoulli_p : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%_fx_const_folded_attrs, 0.6), kwargs = {})
    %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli_p), kwargs = {})
    %div_tensor : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul_tensor, 0.6), kwargs = {})
    %submod_1 : [num_users=0] = call_module[target=submod_1](args = (%x, %_fx_const_folded_attrs), kwargs = {})
    return (div_tensor,)
```

`submod_1` is dangling, unused, and already inlined into the graph.

## Fix

This PR updates the `const_fold._inline_module` function to explicitly remove the unused non-const submodule after its ops have been inlined into the main graph.

Test Plan:
Added a test in `test_fx_const_fold.py`. The test would have failed before this PR because it yields the example graph above, leaving an unused `call_module[target=submod_1]` op. With this PR, the module is correctly erased from the main graph.

Differential Revision: D86056354

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv
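For context, here is a minimal sketch of the failure mode. The toy module and the final assertion are illustrative assumptions on my part, not the actual regression test added in `test_fx_const_fold.py`:

```python
import torch
from torch.fx.experimental import const_fold


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer makes `self.base + 0.6` a const-foldable subgraph.
        self.register_buffer("base", torch.zeros(5, 10))

    def forward(self, x):
        folded = self.base + 0.6
        # bernoulli is nondeterministic, so it stays in the non-const
        # submodule that gets inlined back into the main graph.
        mask = torch.ops.aten.bernoulli.p(folded, 0.6)
        return x * mask / 0.6


folded_gm = const_fold.split_const_subgraph(M())

# Before this fix, an unused call_module[target=submod_1] node could be
# left dangling in the main graph; with the fix it is erased.
assert not any(n.op == "call_module" for n in folded_gm.graph.nodes)
print(folded_gm.graph)
```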
Review of torch/fx/experimental/const_fold.py (outdated diff):

```python
# Explicitly remove the module that was just inlined,
# there should not be any users, but check just in case.
if len(call_mod_node_to_replace.users) == 0:
    gm.graph.erase_node(call_mod_node_to_replace)
```
blaine-rister commented:

Should we assert this if we expect it to always be true? Although that may be redundant, as I have a feeling `gm.graph.erase_node` will raise an error if the node being removed has users.
Suggested change:

```diff
-# Explicitly remove the module that was just inlined,
-# there should not be any users, but check just in case.
-if len(call_mod_node_to_replace.users) == 0:
-    gm.graph.erase_node(call_mod_node_to_replace)
+# Remove the module that was just inlined.
+assert len(call_mod_node_to_replace.users) == 0, "Failed to erase inlined submodule because it is still in use!"
+gm.graph.erase_node(call_mod_node_to_replace)
```
I guess it depends on whether the pass would still function properly if this were false. Would the transform be sound if we didn't erase the node, or would we end up with 2 copies of the module, one inlined and one outlined?
jazlyn5 replied:

So the transformed graph is still runnable if you don't remove the unused `call_module` node. But we certainly do expect to erase the node, per the logic in `_inline_module` and `_verify_const_fold_mod` in `test_fx_const_fold.py`.

> I have a feeling gm.graph.erase_node will raise an error if we try to remove a node with users

Yup, tested and this is correct: erasing a node that still has users raises an error (see the sketch below).

> Should we assert this if we expect it to always be true?

From my understanding and a read of the code, yes, we expect `len(call_mod_node_to_replace.users) == 0`. I was being extra cautious with an `if` check just in case, since the graph is still runnable if we don't end up removing this module.
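As a reference for the exchange above, a tiny standalone sketch (a toy graph, not code from this PR) showing that `Graph.erase_node` refuses to remove a node that still has users:

```python
import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
y = g.call_function(torch.relu, (x,))
g.output(y)

try:
    g.erase_node(x)  # x is still used by the relu node
except RuntimeError as e:
    # e.g. "Tried to erase Node x but it still had 1 users in the graph!"
    print(e)
```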
blaine-rister left a comment:
Mostly LGTM. Stamping this to unblock, but please see my comment below before merging.
Force-pushed ddd68dc to 12a0499.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merged. Pull Request resolved: #166871. Approved by: https://github.com/blaine-rister, https://github.com/mlazos