Skip to content

[Inductor-FX] Support torch.cond#163234

Closed
blaine-rister wants to merge 14 commits intomainfrom
brister/fx_cond
Closed

[Inductor-FX] Support torch.cond#163234
blaine-rister wants to merge 14 commits intomainfrom
brister/fx_cond

Conversation

@blaine-rister
Copy link
Contributor

@blaine-rister blaine-rister commented Sep 18, 2025

Feature

Support torch.cond in the FX converter. The generated FX IR is conceptually indentical to what would come from torch.export:

  • Submodules as stored as attributes, and accessed via getattr.
  • The conditional is represented as torch.ops.higher_order.cond, which takes in the subgraphs, a predicate and submodule inputs.

Implementation overview

The FX backend generates code for subgraphs using the following steps:

  1. When codegen_conditional is called in WrapperFxCodegen, we emit a ConditionalLine.
    a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
  2. At the beginning of FX conversion, generate get_attr nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
  3. When we see the ConditionalLine in the FX converter, we generate a corresponding torch.ops.higher_order.cond.

Implementation details

This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of PythonWrapperCodegen are overridden by SubgraphPythonWrapperCodegen. To apply these overrides, we use multiple inheritance with the registered subclass of WrapperFxCodegen.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including EnterSubgraphLine and ExitSubgraphLine, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.

In contrast, FX IR typically represents a subgraph call as a single HOP node, or a call_module op. To account for this difference, this PR introduces a new wrapper IR line called ConditionalLine, which is only used by the FX backend. We override the codegen_conditional method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to ConditionalLine, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as generate_subgraph_common. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as V.graph.module, but this is only true for the parent graph, and not subgraphs. Instead, we need to call get_graph_inputs and get_graph_outputs to populate the inputs and outputs for subgraphs.

Test plan

This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1

It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163234

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a94d40f with merge base 1aeac30 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

# Subgraphs override some methods of PythonWrapperCodegen.
# Apply these overrides to the user-provided class, with priority given to
# user-provided methods.
class SubgraphFxWrapperCodegen(cls, SubgraphPythonWrapperCodegen):
Copy link
Contributor Author

@blaine-rister blaine-rister Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy doesn't seem to support using a variable as a base class, but it's valid Python. Without this, derived classes would need their own create method, which would need to define a new class to handle subgraphs. This trick lets us handle everything in one class, without custom backends needing to know about subgraphs.


# Register the FX backend.
register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
# Register the FX backend, storing the default for later.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a slight upgrade to fix local testing. AOTI tests could break if they were run after the torch.compile tests, since the latter modifies device_codegens. This restores the old backend when torch.compile tests conclude.

"""
Get the input nodes corresponding to FX graph placeholders.
"""
if V.aot_compilation and not self.is_subgraph:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this logic over from the FX converter, since we need to call get_graph_inputs for subgraphs.

@blaine-rister blaine-rister marked this pull request as ready for review September 19, 2025 18:07
return wrapper # type: ignore[return-value]


def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes a longstanding issue where cache_on_self's type signature was incompatible with @property. In a follow-up, we can remove a lot of @no_type_check decoratorations like this one.

Copy link
Contributor

@angelayi angelayi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks so much for the detailed PR description! made reviewing the PR a lot easier!

had some minor nits

@blaine-rister
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 20, 2025
@blaine-rister blaine-rister added the topic: not user facing topic category label Sep 20, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team Raised by workflow job

@blaine-rister
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually indentical to what would come from `torch.export`:
- Submodules as stored as attributes, and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.

# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.

In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.

# Test plan
This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.

Pull Request resolved: pytorch#163234
Approved by: https://github.com/angelayi, https://github.com/jansel
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually indentical to what would come from `torch.export`:
- Submodules as stored as attributes, and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.

# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.

In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.

# Test plan
This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.

Pull Request resolved: pytorch#163234
Approved by: https://github.com/angelayi, https://github.com/jansel
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually indentical to what would come from `torch.export`:
- Submodules as stored as attributes, and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.

# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.

In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.

# Test plan
This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.

Pull Request resolved: pytorch#163234
Approved by: https://github.com/angelayi, https://github.com/jansel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor release notes: fx release notes category topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants