update Node.is_impure check if subgraph contains impure ops #166609
jazlyn5 wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166609
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit cac714d with merge base 85b035c. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed ad498e7 to cb317fa (Compare)
```
    return kernel


def _fake_tensor(
```
Looks like this is an artifact of ghexport. Not reviewing these changes. If you want to make them go away, along with the linter error, you could try checking out this branch on git and merging with the main branch (`git merge main`), or using `git checkout`/`git reset` to revert these files to the content in the main branch. But then you may have to unlink your PR from the internal diff. Alternatively, commenting `@pytorchbot rebase` might work.
Force-pushed cb317fa to 99253e4 (Compare)
Force-pushed 99253e4 to f654256 (Compare)
Force-pushed ca60ef5 to 074f845 (Compare)
```
if isinstance(target_mod, torch.fx.GraphModule):
    return subgraph_has_impure_ops(target_mod)
else:
    return getattr(target_mod, "_is_impure", False)
```
Where is this `_is_impure` coming from? I would check if the op is mutable, or has the `nondeterministic_seeded` tag. The fx const fold you are using is not used in the torch.compile path, fwiw.
See: pytorch/torch/_inductor/constant_folding.py, lines 214 to 284 in 0d50e5d
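For illustration, a minimal sketch of the kind of check being suggested (not the code in this PR), assuming the standard `OpOverload` schema and tag APIs:
```
import torch


def op_is_unsafe_to_fold(op) -> bool:
    # Treat an overload as unsafe to const-fold if it mutates an input
    # or is a seeded RNG op (tagged nondeterministic_seeded).
    if not isinstance(op, torch._ops.OpOverload):
        return False
    schema = op._schema
    mutates_input = any(
        a.alias_info is not None and a.alias_info.is_write for a in schema.arguments
    )
    nondeterministic = torch.Tag.nondeterministic_seeded in op.tags
    return mutates_input or nondeterministic


# torch.ops.aten.randn.default carries the nondeterministic_seeded tag,
# so it would be reported as unsafe to fold.
```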
I'm not sure; `return getattr(target_mod, "_is_impure", False)` was there before, so I just added the GraphModule check on top of it.
Ah yes, we use `const_fold.split_const_subgraph` out of tree; this is before hitting the point where we call torch.compile.
@eellison for reference, it seems like this is an experimental FX feature. It's used by some passes like DCE and constant folding. https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Node.is_impure
Force-pushed 074f845 to 2efb1ac (Compare)
Force-pushed be3c397 to dbd65b0 (Compare)
```
        )
        self._verify_const_fold_mod(mod_folded)


    def test_do_not_fold_impure_subgraph(self):
```
very nit: If these 2 tests only differ at the beginning and end, you might consider deduplicating the middle logic with @parametrize. The arguments could be the forward function to be compiled, e.g. torch.randn(5, 10) + 1, and a boolean saying whether you expect the GM to be folded. Here's an example.
pytorch/test/inductor/test_op_dtype_prop.py, lines 186 to 188 in 3206677
This is just one option, feel free to ignore it.
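For illustration, a rough sketch of what that deduplication could look like (hypothetical test and class names, not the code in this PR):
```
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class TestConstFoldSubgraphs(TestCase):
    @parametrize("use_impure_op", [False, True])
    def test_fold_subgraph(self, use_impure_op):
        # Shared middle logic: build the parent/submodule graphs (using randn
        # when use_impure_op is True, a pure op otherwise), run
        # const_fold.split_const_subgraphs, then assert on whether folding happened.
        expect_folded = not use_impure_op
        ...  # build graphs, fold, and assert here


instantiate_parametrized_tests(TestConstFoldSubgraphs)

if __name__ == "__main__":
    run_tests()
```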
Force-pushed dbd65b0 to 2e94bd0 (Compare)
Force-pushed 2e94bd0 to cac714d (Compare)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary:
## Context
When `const_fold.split_const_subgraphs` sees a `call_module` node whose target is a GraphModule, the existing implementation can mark this node as const-foldable when it shouldn't.
For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contains impure ops inside.
```
parent graph():
%sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
return (getitem,)
submodule graph():
%randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
return (add,)
```
When the `submodule` graph is fed to `const_fold.split_const_subgraphs`, it comes out unmodified since `randn` is impure.
But if `submodule` is called by a `parent` graph, then when `parent` is fed to `const_fold.split_const_subgraphs`, it comes out folded.
```
parent after fold graph():
%_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
return (_fx_const_folded_attrs,)
```
This is because the `node.is_impure()` check inside `const_fold.split_const_subgraphs` falls through, leading the `call_module` node to be marked as pure.
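For concreteness, a minimal sketch of reproducing this by building the two graphs above by hand (illustrative only, not the exact test added in this PR):
```
import operator
import torch
from torch.fx.experimental import const_fold

# Submodule graph: randn (impure) followed by add, matching the dump above.
sub_graph = torch.fx.Graph()
randn = sub_graph.call_function(
    torch.ops.aten.randn.default, ([5, 10],), {"device": "cpu", "pin_memory": False}
)
add = sub_graph.call_function(torch.ops.aten.add.Tensor, (randn, 1))
sub_graph.output((add,))
sub_gm = torch.fx.GraphModule(torch.nn.Module(), sub_graph)

# Parent graph: a no-input call_module to the impure submodule, then a getitem.
parent_root = torch.nn.Module()
parent_root.sub = sub_gm
parent_graph = torch.fx.Graph()
sub_out = parent_graph.call_module("sub", ())
item = parent_graph.call_function(operator.getitem, (sub_out, slice(None)))
parent_graph.output((item,))
parent_gm = torch.fx.GraphModule(parent_root, parent_graph)

folded = const_fold.split_const_subgraphs(parent_gm)
# Before this fix the call_module to `sub` was treated as const-foldable;
# with the fix it is left in place because the subgraph contains randn.
print(folded.graph)
```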
## Fix
We can update the `fx.node.Node.is_impure` function to check the ops inside a `call_module` node with an additional `subgraph_has_impure_ops` check (see the sketch below):
- if a call_module node calls a GraphModule,
- check whether any call_function nodes are impure ops
- recursively check any call_module nodes that call a GraphModule
If the call_module subgraph has impure ops, `is_impure` returns True.
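A rough sketch of that recursive check (illustrative; not necessarily identical to the code in this PR):
```
import torch


def subgraph_has_impure_ops(gm: torch.fx.GraphModule) -> bool:
    # An impure call_function anywhere in the subgraph (or in a nested
    # GraphModule it calls) makes the enclosing call_module impure.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.is_impure():
            return True
        if node.op == "call_module":
            target_mod = gm.get_submodule(node.target)
            if isinstance(target_mod, torch.fx.GraphModule) and subgraph_has_impure_ops(target_mod):
                return True
    return False
```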
Test Plan: added tests to test_fx_const_fold.py
Differential Revision: D85798483
Pull Request resolved: #166609
Approved by: https://github.com/blaine-rister
…le in const_fold._inline_module (#166871)

Summary:
#166609 updated the `is_impure` check to also look at the ops inside a subgraph when deciding whether a `call_module` node is pure or not.
This change of behavior affects dead code elimination, commonly run as `gm.graph.eliminate_dead_code()`. Specifically, dead code elimination will not erase a node that has no users if the node has side effects or is impure. With the above-mentioned PR, dead code elimination no longer eliminates unused subgraphs that contain side-effectful ops.
This affects `const_fold.split_const_subgraph`, which does the following:
1. split a graph into two submodules, one containing all const ops and one containing non-const ops
2. inline the submodule containing non-const ops back into the main graph
3. run dead code elimination to remove the unused non-const submodule
With PR #166609, step 3 no longer erases the unused module.
As an example, exported graph
```
graph():
    %x : [num_users=2] = placeholder[target=x]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([5, 10], [0, 1]), kwargs = {device: cpu, pin_memory: False})
    %bernoulli : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%empty_permuted, 0.6), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul, 0.6), kwargs = {})
    return (div,)
```
After running const_fold, empty_permuted is const-folded, the rest of the ops are not, and the main graph looks like
```
graph():
    %x : [num_users=3] = placeholder[target=x]
    %_fx_const_folded_attrs : [num_users=2] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %bernoulli_p : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%_fx_const_folded_attrs, 0.6), kwargs = {})
    %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli_p), kwargs = {})
    %div_tensor : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul_tensor, 0.6), kwargs = {})
    %submod_1 : [num_users=0] = call_module[target=submod_1](args = (%x, %_fx_const_folded_attrs), kwargs = {})
    return (div_tensor,)
```
`submod_1` is dangling, unused, and just inlined into the graph.

## Fix
This PR updates the `const_fold._inline_module` function to explicitly remove the unused non-const submodule after it has inlined the submodule's ops into the main graph.

Test Plan: Added a test in `test_fx_const_fold.py`. The test would have failed before this PR because it yields the above example graph, leaving an unused `call_module[target=submod_1]` op. With the PR, the module is erased from the main graph correctly.

Differential Revision: D86056354
Pull Request resolved: #166871
Approved by: https://github.com/blaine-rister, https://github.com/mlazos
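For illustration, a minimal sketch of the removal step described in that fix (hypothetical helper name; the actual change lives inside `const_fold._inline_module`):
```
import torch


def remove_inlined_submodule(gm: torch.fx.GraphModule, submod_name: str) -> None:
    # After the submodule's ops have been inlined into the main graph, the
    # call_module node to it has no users; erase it and drop the attribute so
    # dead code elimination does not need to (and, post-#166609, will not) do it.
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and node.target == submod_name and not node.users:
            gm.graph.erase_node(node)
    if hasattr(gm, submod_name):
        delattr(gm, submod_name)
    gm.recompile()
```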
… unblock production (#167443)

Summary:
#166609 updates `node.is_impure` to consider a submodule as impure if the submodule contains an impure node. This in turn changes the behavior of `graph.eliminate_dead_code()`, which does not eliminate nodes with side effects; see the [pytorch documentation](https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Graph.eliminate_dead_code):
> Remove all dead code from the graph, based on each node's number of users, and whether the nodes have any side effects.
While it is correct that a submodule containing side-effectful ops is itself side-effectful and should not be dead-code eliminated, some customers rely on dead code elimination to eliminate submodules that contain impure ops, which is the behavior before the #166609 fix.
Due to production environment constraints, we have to revert #166609 and move the side-effectful submodule check logic to `const_fold.py`, which will correctly **not** const-fold a submodule that contains impure ops.
NOTE: other call sites that use `node.is_impure()` to make decisions are still incorrectly eliminating side-effectful submodules, but we can't safely change that today.

## This PR
- move `_subgraph_has_impure_op` into `fx/experimental/const_fold.py`, check and prevent const-folding an impure submodule
- added a note in `node.is_impure` to highlight the incorrect behavior and context in case people go looking in the future

Test Plan: run test_fx_const_fold and all tests pass

Differential Revision: D86641994
Pull Request resolved: #167443
Approved by: https://github.com/jfix71
cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben