
Inductor Lite Mode #167115

Closed
BoyuanFeng wants to merge 28 commits into main from bf/lite

Conversation

@BoyuanFeng
Contributor

@BoyuanFeng BoyuanFeng commented Nov 5, 2025

This PR introduces inductor lite mode for opt-in optimizations and numeric correctness guarantees.

Unlike the default mode, which applies all possible fusions, lite mode gives control back to the user and provides a guarantee of numeric correctness. Specifically, this mode:

  • Fallback by Default: Fall back for ALL nodes by default, unless users explicitly mark a node for inductor fusion.
  • Selective Decomposition: Skip decomposition for all nodes except user-marked nodes.
  • Regional inductor compile
  • Skip dead code elimination
  • Skip buffer reuse
  • Skip reordering passes, such as reordering for peak memory, reordering for compute/comm overlap, and reorder_for_reducing_graph_partitions.
  • Skip all pre-grad, joint-graph, and post-grad passes.

Example: Flex Attention

import torch
import torch.fx.traceback as fx_traceback
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def _squared(score, b, h, m, n):
    return score * score

def mask_mod(b, h, q, k):
    return q >= 0

a, b = 12, 64
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b, device="cuda")

def fn(x):
    x = torch.sin(x)
    # only this annotated region is decomposed and compiled with inductor
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
    return torch.cos(x)

x = torch.randn(1, 1, a * b, b, dtype=torch.bfloat16, device="cuda", requires_grad=True)

opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
opt_fn(x)
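
A quick way to see this in practice (my own suggestion, not something the PR documents) is to turn on inductor output-code logging before running the compiled function; with lite mode, only the annotated flex_attention region should show up as generated code:

```python
import torch

# Print the code inductor generates; everything outside the annotated region
# should fall back rather than appear here. Reuses opt_fn and x from above.
torch._logging.set_logs(output_code=True)

opt_fn(x)
```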

code diff: https://www.internalfb.com/intern/diffing/?paste_number=2027441476

default mode tlp: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpYAzDxX/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 vs lite mode tlp: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpnnuh1W/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Numerics

Inductor lite mode provides bitwise equivalence with the aot_eager backend on torchtitan llama3-8b and DeepSeek v3. pytorch/torchtitan#2005
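
For illustration, a minimal sketch of what such a bitwise comparison can look like (the helper and the toy function are my own, not the torchtitan harness, and it assumes a CUDA build with this PR's lite mode available):

```python
import torch
import torch._dynamo

def assert_bitwise_equal(fn, *inputs):
    # Compare lite mode against the aot_eager backend bit for bit.
    torch._dynamo.reset()
    ref = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inputs)
    torch._dynamo.reset()
    out = torch.compile(fn, mode="lite", fullgraph=True)(*inputs)
    assert torch.equal(ref, out), "lite mode output differs bitwise from aot_eager"

assert_bitwise_equal(lambda t: torch.cos(torch.sin(t)), torch.randn(8, device="cuda"))
```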

close: #167012

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Nov 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167115

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 3 Unrelated Failures

As of commit 8690e80 with merge base ce8672c:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and release notes: fx labels Nov 5, 2025
@BoyuanFeng BoyuanFeng marked this pull request as draft November 5, 2025 19:45
@ezyang
Contributor

ezyang commented Nov 6, 2025

How does this relate to the cudagraph PR?

@BoyuanFeng
Contributor Author

@ezyang I plan to land the inductor lite mode first, without cudagraphs. It will be a codegen mode and support regional inductor compile. The purpose is to unblock other issues such as #167012.

@BoyuanFeng BoyuanFeng requested a review from angelayi November 6, 2025 06:51
@BoyuanFeng BoyuanFeng marked this pull request as ready for review November 7, 2025 02:01
"use_pre_grad_passes": False,
"use_joint_graph_passes": False,
"use_post_grad_passes": False,
"use_decomposition": False,
Contributor

I have a long term maintenance question here. Essentially, I expect most development in Inductor to be operating on non-lite mode. How do we ensure that optimizations do not accidentally get "turned on" for lite mode users? Seeing that so many specific optimizations have to be toggled makes me worried that it will be easy for someone to add a new pass and forget to make sure it doesn't apply in lite mode. The extra problem is that if a pass only applies in certain situations, it is possible for a pass to be silently enabled and we will only find out when user code changes to trigger the pass that we hadn't properly excluded it from lite mode.

Contributor

In particular, the new tests in test_torchinductor.py don't /feel/ sufficient. But I am also not sure about a good testing strategy that doesn't result in huge amounts of extra tests having to be run. Maybe there is a lightweight strategy involving asserts/lints that we could use to avoid problems? Or maybe you think the Inductor team can be socialized to not mess up lite mode. We should also consider adding lite as a benchmark configuration for our benchmark suite, and fuzz for EXACT bitwise equivalence in the tests there (we traditionally can't do exact bitwise, but we absolutely can for the benchmark models!)
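
To make the asserts/lints idea concrete, here is one hypothetical shape it could take (none of this is in the PR; run_graph_pass and the allow-list are invented names, and the check keys off the fallback_by_default config this PR adds):

```python
import torch._inductor.config as inductor_config

# Passes explicitly permitted under lite mode (empty in this sketch).
LITE_MODE_ALLOWED_PASSES: set[str] = set()

def run_graph_pass(name, pass_fn, gm):
    # Fail loudly if any graph pass tries to run while lite mode is active,
    # so a newly added pass cannot be silently enabled for lite mode users.
    lite_mode_active = getattr(inductor_config, "fallback_by_default", False)
    if lite_mode_active and name not in LITE_MODE_ALLOWED_PASSES:
        raise AssertionError(
            f"graph pass {name!r} is not allow-listed for inductor lite mode"
        )
    return pass_fn(gm)
```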

Contributor Author

There are roughly 4 groups of configs. Group 3 needs a bit of attention for maintenance; see the sketch after the groups below.

Group 1: these configs fall back by default and only decompose & fuse explicitly user-annotated nodes.

    "fallback_by_default": True,
    "selective_decompose": True,

Group 2: This turns off all passes and should not require much attention.

    "use_pre_grad_passes": False,
    "use_joint_graph_passes": False,
    "use_post_grad_passes": False,

Group 3: reordering. This needs more maintenance in the future; developers need to add any new reorder config name here to turn it off.

    "reorder_for_peak_memory": False,
    "reorder_for_compute_comm_overlap": False,
    "triton.reorder_for_reducing_graph_partitions": False,

Group 4: specifically skips DCE and buffer reuse.

    "use_dce": False,
    "allow_buffer_reuse": False,

# Use fx graph passes
use_pre_grad_passes: bool = True
use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True
Contributor

Someone more familiar with Inductor config than I am should check if there are any "duplicate" settings here.

and isinstance(
    n.target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
)
and should_fallback_by_default(n)
Contributor

I'm a little confused by the choice you made here. You have to force the triton kernel wrapper not to go through the fallback. So why not just move this branch later in the if-else chain here?

Contributor

Another confusion I have is that the condition here is not symmetric with the bisector condition above; maybe there's a good reason, but I'd at least like it to be remarked on.

Contributor Author

@BoyuanFeng BoyuanFeng Nov 10, 2025

The major reason is to keep all fallback decisions due to lite mode together. Alternatively, we could implement this as two if branches (one for OpOverload, symmetric with the bisector, and one for HOP fallback). I added more docs.

# The current use case is regional inductor compile, where users specify nodes that
# should be decomposed and compiled with inductor. When true, we skip decomposition
# for all operators except the user specified nodes.
selective_decompose: bool = False
Contributor

TBH I don't understand what this config does. In particular, if this is "selective" decompose, how exactly do I specify if something is to be decomposed or not? The documentation here doesn't say.

Contributor Author

This config decomposes all nodes annotated in the regional compile context manager AND does not decompose any other nodes. We need it because, for the regional inductor compiler, we still need decomposition for those nodes. Let me add more docs here.

Also, would regional_decompose be a better name to match regional_inductor?

joint_gm, should_decompose, decomposition
).run(*args)

return make_fx(wrap_fn, decomposition_table={})(*args)
Contributor

I think one big question that I have is why we are doing this as a retrace, as opposed to adjusting the decomposition when we were doing the original trace. It would certainly be faster (compile time wise) if we avoided having to do another full make_fx retrace here.

Contributor Author

@BoyuanFeng BoyuanFeng Nov 10, 2025

We need to adjust the decomposition based on node.meta["custom"]. The easiest way is an interpreter that iterates over nodes, checks metadata, and decides whether to decompose each one.

During the original trace, we only have a runnable callable and don't have the fx graph yet. For the retrace, we can use the fx graph from the original trace. Not sure if there are other ways to do that.
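
Roughly, the interpreter approach described here might look like the following sketch (class and helper names are mine; the metadata key follows the compile_with_inductor annotation from the example above, and the real PR wires this into a make_fx retrace):

```python
import torch
import torch.fx
from torch._decomp import core_aten_decompositions

def should_decompose(n: torch.fx.Node) -> bool:
    # node.meta["custom"] is populated by fx_traceback.annotate(...), e.g.
    # {"compile_with_inductor": 0} in the flex attention example above.
    return "compile_with_inductor" in n.meta.get("custom", {})

class SelectiveDecomposeInterpreter(torch.fx.Interpreter):
    """Sketch: re-run a traced graph and decompose only annotated nodes."""

    def __init__(self, gm, should_decompose, decomposition_table):
        super().__init__(gm)
        self.should_decompose = should_decompose
        self.decomposition_table = decomposition_table

    def run_node(self, n):
        if (
            n.op == "call_function"
            and self.should_decompose(n)
            and n.target in self.decomposition_table
        ):
            # Call the registered decomposition instead of the op itself; under
            # make_fx this retraces the decomposed ops into the new graph.
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            return self.decomposition_table[n.target](*args, **kwargs)
        return super().run_node(n)

# e.g. SelectiveDecomposeInterpreter(joint_gm, should_decompose,
#                                    core_aten_decompositions()).run(*args)
```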

Contributor

But you do: when you do an FX trace, you're always generating nodes and making a decision to decompose or not in proxy tensor mode.

Contributor Author

Created issue #167520 and will do this in a follow-up PR.

Contributor

I agree that this could be done without retracing everything, which is slow.

Contributor

@ezyang ezyang left a comment

Approving to unblock; please check comments.

@BoyuanFeng
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Nov 11, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15)

Details for Dev Infra team Raised by workflow job

@BoyuanFeng
Contributor Author

@pytorchbot merge -f "skip unrelated test__dyn_quant_matmul_4bit_bf16_input"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Contributor

@eellison eellison left a comment

looks good, couple comments.

Comment on lines +4053 to +4059
# some ops need special handle due to dynamic shapes. we can avoid
# fallback if they do not impact numerics.
skip_fallback_due_to_dynamic_shape = OrderedSet(
    [
        torch.ops.aten._assert_scalar.default,
        torch.ops.aten.lift_fresh_copy.default,
    ]
Contributor

Can you comment on this ? why ?

Contributor Author

test_lite_dynamic_shape_assertion_cuda gives a graph that contains a ge comparison followed by _assert_scalar.

When we fall back for _assert_scalar, ExternKernel would execute it on the ge result and error at

example_output = kernel(*new_args, **new_kwargs)

More context in #167012.

joint_gm, should_decompose, decomposition
).run(*args)

return make_fx(wrap_fn, decomposition_table={})(*args)
Contributor

I agree that this could be done without retracing everything, which is slow.

Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025

Pull Request resolved: pytorch#167115
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the bf/lite branch December 13, 2025 02:17

Labels

ciflow/inductor · ciflow/trunk · fx · Merged · module: inductor · release notes: fx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

regional inductor error out when input has unbacked symint expressions

6 participants