[inductor] generate fused rms/layer norm bwd#165370
Closed
shunting314 wants to merge 35 commits intogh/shunting314/237/basefrom
Closed
[inductor] generate fused rms/layer norm bwd#165370shunting314 wants to merge 35 commits intogh/shunting314/237/basefrom
shunting314 wants to merge 35 commits intogh/shunting314/237/basefrom
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165370
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ⏳ No Failures, 1 PendingAs of commit 1ac94b7 with merge base 4e6afa8 ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
PaulZhang12
reviewed
Oct 24, 2025
RMS/Layer norm backward would generated 2 kind of reductions: - the reduction computing dx which reduce across the hidden dimension (in the context of transformer) - the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension. These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders. There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically. The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 . To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload. To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following: 1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening. 2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction 3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16). Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
shunting314
commented
Oct 24, 2025
RMS/Layer norm backward would generated 2 kind of reductions: - the reduction computing dx which reduce across the hidden dimension (in the context of transformer) - the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension. These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders. There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically. The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 . To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload. To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following: 1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening. 2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction 3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16). Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
jansel
approved these changes
Oct 26, 2025
RMS/Layer norm backward would generated 2 kind of reductions: - the reduction computing dx which reduce across the hidden dimension (in the context of transformer) - the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension. These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders. There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically. The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 . To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload. To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following: 1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening. 2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction 3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16). Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
RMS/Layer norm backward would generated 2 kind of reductions: - the reduction computing dx which reduce across the hidden dimension (in the context of transformer) - the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension. These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders. There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically. The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 . To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload. To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following: 1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening. 2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction 3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16). Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Closed
RMS/Layer norm backward would generated 2 kind of reductions: - the reduction computing dx which reduce across the hidden dimension (in the context of transformer) - the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension. These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders. There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically. The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 . To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload. To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following: 1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening. 2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction 3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16). Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
shunting314
commented
Oct 27, 2025
| enable_pdl = False | ||
|
|
||
| mix_order_reduction = ( | ||
| os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0") == "1" |
Contributor
Author
There was a problem hiding this comment.
Off by default in this PR. Will reenable in following PRs
Contributor
Author
|
@pytorchbot merge |
Collaborator
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This was referenced Oct 28, 2025
pytorch-bot bot
pushed a commit
that referenced
this pull request
Oct 29, 2025
Signed-off-by: xinan.lin <xinan.lin@intel.com>
pytorchmergebot
pushed a commit
that referenced
this pull request
Oct 30, 2025
…166384) This PR reused native_mm and mix_order_reduction for Intel GPU and enabled the corresonding test. Fixes #165370 Pull Request resolved: #166384 Approved by: https://github.com/jansel
BoyuanFeng
pushed a commit
that referenced
this pull request
Oct 31, 2025
…166384) This PR reused native_mm and mix_order_reduction for Intel GPU and enabled the corresonding test. Fixes #165370 Pull Request resolved: #166384 Approved by: https://github.com/jansel
etaf
added a commit
to etaf/pytorch-inductor-xpu
that referenced
this pull request
Nov 4, 2025
…ytorch#166384) This PR reused native_mm and mix_order_reduction for Intel GPU and enabled the corresonding test. Fixes pytorch#165370 Pull Request resolved: pytorch#166384 Approved by: https://github.com/jansel
Khanaksahu
pushed a commit
to Khanaksahu/pytorch
that referenced
this pull request
Nov 17, 2025
ghstack-source-id: 077edea Pull Request resolved: pytorch/pytorch#165370
Khanaksahu
pushed a commit
to Khanaksahu/pytorch-fork
that referenced
this pull request
Nov 17, 2025
ghstack-source-id: 410dcec Pull Request resolved: pytorch/pytorch#165370
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stack from ghstack (oldest at bottom):
RMS/Layer norm backward would generated 2 kind of reductions:
These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.
There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.
The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .
To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.
To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben