
update Node.is_impure check if subgraph contains impure ops #166609

Closed

jazlyn5 wants to merge 1 commit into pytorch:main from jazlyn5:export-D85798483

Conversation

@jazlyn5 (Contributor) commented Oct 30, 2025

Summary:

Context

When `const_fold.split_const_subgraphs` sees a `call_module` node that targets a `GraphModule`, the existing implementation can mark the node as const-foldable when it shouldn't.

For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contains impure ops inside.

```
parent graph():
    %sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
    return (getitem,)

submodule graph():
    %randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
    return (add,)
```

When the `submodule` graph is fed to `const_fold.split_const_subgraphs` on its own, it comes out unmodified, since `randn` is impure.

But if the `submodule` is called by a `parent` graph, feeding `parent` to `const_fold.split_const_subgraphs` yields a folded result:

```
parent after fold graph():
    %_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    return (_fx_const_folded_attrs,)
```

This is because the `node.is_impure()` check inside `const_fold.split_const_subgraphs` falls through, causing the `call_module` node to be marked as pure.

Fix

We can update the `fx.node.Node.is_impure` function to check the ops inside a `call_module` node with an additional `subgraph_has_impure_ops` check (see the sketch below):

  • if a `call_module` node calls a `GraphModule`,
  • check whether any of its `call_function` nodes are impure ops, and
  • recursively check any `call_module` nodes that themselves call a `GraphModule`.

If the `call_module` subgraph has impure ops, `is_impure` returns `True`.
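
For reference, a minimal sketch of what such a check could look like; `subgraph_has_impure_ops` is the helper name from this PR, but the body below is illustrative rather than the merged implementation:

```
import torch.fx


def subgraph_has_impure_ops(gm: torch.fx.GraphModule) -> bool:
    # Any impure call_function inside the subgraph poisons the whole module.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.is_impure():
            return True
        # Recurse into nested GraphModules reached via call_module.
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.fx.GraphModule) and subgraph_has_impure_ops(submod):
                return True
    return False
```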

Test Plan: added tests to test_fx_const_fold.py

Differential Revision: D85798483

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

pytorch-bot bot commented Oct 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166609

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cac714d with merge base 85b035c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-codesync bot commented Oct 30, 2025

@jazlyn5 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85798483.

jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
```
return kernel

def _fake_tensor(
```

@blaine-rister (Contributor) commented Oct 30, 2025

Looks like this is an artifact of ghexport. Not reviewing these changes. If you want to make them go away, along with the linter error, you could try checking out this branch in git and merging with the main branch (`git merge main`), or using `git checkout`/`git reset` to revert these files to the content on the main branch. But then you may have to unlink your PR from the internal diff. Alternatively, commenting `@pytorchbot rebase` might work.


@blaine-rister (Contributor) left a comment

LGTM!

pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 30, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
```
if isinstance(target_mod, torch.fx.GraphModule):
    return subgraph_has_impure_ops(target_mod)
else:
    return getattr(target_mod, "_is_impure", False)
```
A reviewer (Contributor) commented:

Where is this `_is_impure` coming from? I would check if the op is mutable, or has the tag `nondeterministic_seeded`. The fx const fold you are using is not used in the torch.compile path, fwiw.

See:

```
# TODO - fix errors with this
if (
    node.op == "call_function"
    and node.target == aten._efficientzerotensor.default
):
    return self.unknown_value

# TODO - constant folding triton kernel returns the inputs -- fix this
if (
    node.op == "call_function"
    and node.name == "triton_kernel_wrapper_functional_proxy"
):
    return self.unknown_value

# skip constructors, since inductor generates optimal code for them already
# and turning into tensor would result in an additional global memory read
# TODO - more complicated strategy
if (
    self.skip_constructors
    and not is_const_source(node, self.lifted_constant_names)
    and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
):
    return self.unknown_value

# All mutations should either be removed or on inputs which we did not make constant
if (
    isinstance(node.target, torch._ops.OpOverload)
    and torch.Tag.nondeterministic_seeded in node.target.tags
):
    return self.unknown_value

if node.op == "call_function" and isinstance(
    node.target, torch._ops.HigherOrderOperator
):
    return self.unknown_value

out = self._deduce_value(node)
if isinstance(out, torch._C.ScriptObject):
    return out
if out == self.unknown_value:
    return self.unknown_value

if not is_const_source(node, self.lifted_constant_names) and (
    isinstance(out, torch.Tensor) or out == self.deferred_value
):
    if out != self.deferred_value and out.device.type == "meta":
        return out

    if not self.insertable_tensor_check(out):
        return out

    if self.is_impure(node):
        return self.unknown_value

    self.add_node_replacement(node, out)

    flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

    for n in flattened_node_inps:
        if not isinstance(n, torch.fx.Node):
            continue

        self.replaced_uses[n] += 1

    for to_delete in self.user_to_last_uses.get(node, []):
        if self.replaced_uses[to_delete] == len(to_delete.users):
            self.node_replacements.pop(to_delete, None)

return out
```
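
For concreteness, a hedged sketch of the alternative checks suggested above, using the op schema's alias info for mutability and op tags for nondeterminism (illustrative only, not the reviewer's code):

```
import torch


def op_blocks_folding(node: torch.fx.Node) -> bool:
    # Only OpOverload targets carry a schema and tags.
    if not isinstance(node.target, torch._ops.OpOverload):
        return False
    schema = node.target._schema
    # Mutable: some argument is aliased and written to.
    mutates = any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in schema.arguments
    )
    # Nondeterministically seeded ops (e.g. aten.randn) must not be folded.
    seeded = torch.Tag.nondeterministic_seeded in node.target.tags
    return mutates or seeded
```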

@jazlyn5 (Contributor, Author) replied:

I'm not sure; `return getattr(target_mod, "_is_impure", False)` was there before, so I just added the GraphModule check on top of it.

Ah yes, we use `const_fold.split_const_subgraphs` out of tree; this runs before we hit the point where we call torch.compile.

@blaine-rister (Contributor) commented Oct 31, 2025

@eellison for reference, it seems like this is an experimental FX feature. It's used by some passes like DCE and constant folding. https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Node.is_impure

jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
jazlyn5 force-pushed the export-D85798483 branch 2 times, most recently from be3c397 to dbd65b0, on October 30, 2025 22:57
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 30, 2025
```
        )
        self._verify_const_fold_mod(mod_folded)

    def test_do_not_fold_impure_subgraph(self):
```

@blaine-rister (Contributor) commented Oct 31, 2025

very nit: If these 2 tests only differ at the beginning and end, you might consider deduplicating the middle logic with `@parametrize`. The arguments could be the forward function to be compiled, e.g. `torch.randn(5, 10) + 1`, and a boolean saying whether you expect the GM to be folded. Here's an example:

```
@parametrize("op_name", OpDtypeSupport.supported_dtypes)
@parametrize("load_upcast_to_fp32", [False, True])
@parametrize("input_dtype", [torch.float16, torch.bfloat16])
```

This is just one option, feel free to ignore it.
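
As a hedged sketch of that suggestion applied here (the test class, parameter name, and elided middle are hypothetical, not the PR's actual test code):

```
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class TestConstFoldParametrized(TestCase):
    @parametrize("expect_folded", [False, True])
    def test_subgraph_folding(self, expect_folded):
        # Hypothetical shared middle: build the parent/submodule graphs
        # (pure or impure depending on the parameter), run
        # const_fold.split_const_subgraphs, then assert on the result.
        ...


instantiate_parametrized_tests(TestConstFoldParametrized)

if __name__ == "__main__":
    run_tests()
```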

jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Oct 31, 2025
@jazlyn5 (Contributor, Author) commented Oct 31, 2025

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 3, 2025
…le in const_fold._inline_module`

Summary:
pytorch#166609 updated the `is_impure` check to inspect the ops inside a subgraph when deciding whether a `call_module` node is pure.

This change in behavior affects dead code elimination, commonly run as `gm.graph.eliminate_dead_code()`. Specifically, dead code elimination will not erase a node that has no users if the node has side effects or is impure. With the above-mentioned PR, dead code elimination no longer eliminates unused subgraphs that contain side-effectful ops.

This affects `const_fold.split_const_subgraphs`, which does the following:
1. split the graph into two submodules, one containing all const ops and one containing non-const ops
2. inline the submodule containing non-const ops back into the main graph
3. run dead code elimination to remove the now-unused non-const submodule

With pytorch#166609, step 3 no longer erases the unused module. As an example, the exported graph
```
 graph():
    %x : [num_users=2] = placeholder[target=x]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([5, 10], [0, 1]), kwargs = {device: cpu, pin_memory: False})
    %bernoulli : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%empty_permuted, 0.6), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul, 0.6), kwargs = {})
    return (div,)
```
After running const_fold, `empty_permuted` is const-folded, the rest of the ops are not, and the main graph looks like
```
graph():
    %x : [num_users=3] = placeholder[target=x]
    %_fx_const_folded_attrs : [num_users=2] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
    %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
    %bernoulli_p : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%_fx_const_folded_attrs, 0.6), kwargs = {})
    %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli_p), kwargs = {})
    %div_tensor : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul_tensor, 0.6), kwargs = {})
    %submod_1 : [num_users=0] = call_module[target=submod_1](args = (%x, %_fx_const_folded_attrs), kwargs = {})
    return (div_tensor,)
```
`submod_1` is left dangling: its ops have already been inlined into the graph, yet the unused `call_module` node remains.

## Fix
This PR updates the `const_fold._inline_module` function to explicitly remove the unused non-const submodule after its ops have been inlined into the main graph; a sketch follows.
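
A minimal sketch of the cleanup described above, assuming a folded GraphModule `gm` whose inlined non-const submodule is named `submod_1` (names are hypothetical; the actual change lives inside `const_fold._inline_module`):

```
import torch.fx


def remove_inlined_submodule(gm: torch.fx.GraphModule, name: str) -> None:
    # After inlining, the call_module node has no users; erase it
    # explicitly instead of relying on dead code elimination.
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and node.target == name and not node.users:
            gm.graph.erase_node(node)
    gm.delete_submodule(name)
    gm.recompile()
```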

Test Plan:
Added a test in `test_fx_const_fold.py`. 

The test would have failed before this PR because it yields the example graph above, leaving an unused `call_module[target=submod_1]` op.

With the PR, the module is erased from the main graph correctly.

Differential Revision: D86056354
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 3, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 3, 2025
pytorchmergebot pushed a commit that referenced this pull request Nov 3, 2025
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
pytorch-bot bot pushed a commit that referenced this pull request Nov 4, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 10, 2025
… unblock production

Summary:
pytorch#166609 updates `node.is_impure` to consider a submodule as impure if the submodule contains an impure node. This in turn changes the behavior of `graph.eliminate_dead_code()`, which does not eliminate nodes with side effects; see the [pytorch documentation](https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Graph.eliminate_dead_code):
> Remove all dead code from the graph, based on each node's number of users, and whether the nodes have any side effects.


While it is correct that a submodule containing side-effectful ops is itself side-effectful and should not be dead-code-eliminated, some customers rely on dead code elimination to remove submodules that contain impure ops, which was the behavior before the pytorch#166609 fix.

Due to production environment constraints, we have to revert pytorch#166609 and move the side-effectful-submodule check logic into `const_fold.py`, which will correctly **not** const-fold a submodule that contains impure ops.

NOTE: other call sites that use `node.is_impure()` to make decisions will still incorrectly eliminate side-effectful submodules, but we can't safely change that today.


## This PR
- move `_subgraph_has_impure_op` into `fx/experimental/const_fold.py` and use it to prevent const-folding an impure submodule (see the sketch after this list)
- add a note in `node.is_impure` to record the incorrect behavior and its context in case people go looking in the future
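
A hedged sketch of what the const_fold-side guard could look like, reusing the `_subgraph_has_impure_op` helper named above (the wiring is illustrative, not the merged code):

```
import torch.fx


def _subgraph_has_impure_op(gm: torch.fx.GraphModule) -> bool:
    # Same recursive walk as sketched in the #166609 description above.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.is_impure():
            return True
        if node.op == "call_module":
            sub = gm.get_submodule(node.target)
            if isinstance(sub, torch.fx.GraphModule) and _subgraph_has_impure_op(sub):
                return True
    return False


def _node_is_const_foldable(gm: torch.fx.GraphModule, node: torch.fx.Node) -> bool:
    # Impure nodes are never const-foldable; neither is a call_module
    # whose subgraph hides impure ops.
    if node.is_impure():
        return False
    if node.op == "call_module":
        sub = gm.get_submodule(node.target)
        if isinstance(sub, torch.fx.GraphModule) and _subgraph_has_impure_op(sub):
            return False
    return True
```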

Test Plan: run test_fx_const_fold and all tests pass

Differential Revision: D86641994
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 10, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 10, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 10, 2025
jazlyn5 added a commit to jazlyn5/pytorch that referenced this pull request Nov 10, 2025
pytorchmergebot pushed a commit that referenced this pull request Nov 10, 2025
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025