Skip to content

[Inductor][Triton][FP8] Support tile-wise (1x128) scaling in Inductor#165132

Closed
jananisriram wants to merge 1 commit intomainfrom
export-D84025878
Closed

[Inductor][Triton][FP8] Support tile-wise (1x128) scaling in Inductor#165132
jananisriram wants to merge 1 commit intomainfrom
export-D84025878

Conversation

@jananisriram
Copy link
Contributor

@jananisriram jananisriram commented Oct 10, 2025

Summary:
Support tile-wise 1x128 scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors a and b represent a 1x128 slice of input.

NOTE: Block-wise 128x128 and 1x128 scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in fbcode (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:

TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5

Differential Revision: D84025878

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165132

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 1 Cancelled Job

As of commit 4820cba with merge base 56a809a (image):

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-codesync
Copy link

meta-codesync bot commented Oct 10, 2025

@jananisriram has exported this pull request. If you are a Meta employee, you can view the originating Diff in D84025878.

@PaulZhang12
Copy link
Contributor

Overall, this looks great, will wait for @eqy to comment about the blas changes cc @eellison

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 10, 2025
@drisspg drisspg added ciflow/b200 ciflow/rocm Trigger "default" config CI on ROCm labels Oct 10, 2025
Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much cleaner, a couple small comments. Great job!

@jananisriram jananisriram added the topic: not user facing topic category label Oct 10, 2025
jananisriram added a commit that referenced this pull request Oct 10, 2025
…#165132)

Summary:
Pull Request resolved: #165132

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Differential Revision: D84025878
Copy link
Contributor

@njriasan njriasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jananisriram See my feedback. Overall this seems good, but I have some suggestions to have a single point of failure for the allowed scaling options. I'll defer to you though regarding what combinations should actually be supported and how to enforce those.

jananisriram added a commit that referenced this pull request Oct 16, 2025
…#165132)

Summary:
Pull Request resolved: #165132

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2025
…#165132)

Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2025
…#165132)

Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
@facebook-github-bot facebook-github-bot force-pushed the export-D84025878 branch 2 times, most recently from 48d3c59 to ec744b3 Compare October 27, 2025 20:49
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2025
…#165132)

Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2025
…#165132)

Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
…#165132)

Summary:

Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Reviewed By: njriasan

Differential Revision: D84025878
@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: Limited CI on H100 / linux-jammy-cuda12_8-py3_10-gcc11-sm90-FA3-ABI-stable-test / test

Details for Dev Infra team Raised by workflow job

@jeanschmidt
Copy link
Contributor

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: Limited CI on H100 / linux-jammy-cuda12_8-py3_10-gcc11-sm90-FA3-ABI-stable-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@jeanschmidt
Copy link
Contributor

@pytorchbot -f "merge with -i is not ignoring a canceled job, as it says it will. Seems a bug"

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 3, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'merge with -i is not ignoring a canceled job, as it says it will. Seems a bug' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick} ...

Try @pytorchbot --help for more info.

@jeanschmidt
Copy link
Contributor

@pytorchbot merge -f "merge with -i is not ignoring a canceled job, as it says it will. Seems a bug"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorch-bot bot pushed a commit that referenced this pull request Nov 4, 2025
…#165132)

Summary:
Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.

NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).

Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```

Differential Revision: D84025878

Pull Request resolved: #165132
Approved by: https://github.com/eqy, https://github.com/drisspg, https://github.com/njriasan
@github-actions github-actions bot deleted the export-D84025878 branch December 4, 2025 02:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants