[pytorch] Make clamp kernel branchless #167889
stashuk-olek wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167889
✅ No failures as of commit 9e98273 with merge base 6461548. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
@stashuk-olek has exported this pull request. If you are a Meta employee, you can view the originating Diff in D86561069.
```
  return ::min(::max(v, lower), upper);
}
scalar_t result = ::min(::max(v, lower), upper);
```
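For readers outside the diff view, the review thread concerns NaN handling around the `::min(::max(...))` idiom. Below is a minimal host-side sketch, in plain C++ with `std::` functions and illustrative names (not the actual ATen kernel), of the branchy vs branchless shapes:

```cpp
#include <algorithm>
#include <cmath>

// Branchy shape: on a GPU, warp lanes taking different cases diverge.
inline float clamp_branchy(float v, float lower, float upper) {
    if (std::isnan(v)) return v;  // explicit NaN propagation
    if (v < lower) return lower;
    if (v > upper) return upper;
    return v;
}

// Branchless shape: min/max lower to predicated/select instructions.
// With comparison-based std::min/std::max and v as the first argument,
// every comparison against NaN is false, so NaN falls through and
// propagates without an explicit branch. (IEEE fminf/fmaxf would NOT
// propagate it, since they return the non-NaN operand, which is why
// manual NaN propagation comes up in this review at all.)
inline float clamp_branchless(float v, float lower, float upper) {
    return std::min(std::max(v, lower), upper);
}
```

Both shapes agree on ordinary values; the difference is whether NaN propagation needs its own `if`.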
The comment suggests the manual NaN propagation is only necessary for ROCm. Is that the case? If so, can the manual NaN propagation be `#ifdef`'d away?
Force-pushed 8da8df4 to 9e98273 (Compare)
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #167889
Approved by: https://github.com/Skylion007
Summary:
The old clamp implementation had control divergence from its `if` statements. This PR reduces branching in the kernel. See the attached benchmark, with source code in the linked stack, for correctness and tests on tensors.
The branchless implementation shows consistent performance gains across all data types and NaN ratios.
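The NaN subtlety behind the manual-propagation question can be reproduced on the host: IEEE `fmax`/`fmin` return the non-NaN operand, so a clamp built on them silently drops NaN and needs a branchless fix-up. A sketch under that assumption (plain C++, illustrative names, not the ATen code):

```cpp
#include <cmath>

// fmin(fmax(NaN, lo), hi) evaluates to lo: the IEEE functions prefer the
// non-NaN operand, quietly "repairing" the NaN instead of propagating it.
// The (v != v) test is true only for NaN; the ternary is typically lowered
// to a select instruction rather than a divergent branch.
inline float clamp_with_nan_fixup(float v, float lower, float upper) {
    float clamped = std::fmin(std::fmax(v, lower), upper);
    return (v != v) ? v : clamped;
}
```

Whether the fix-up is needed on a given backend depends on which min/max intrinsics its clamp lowers to, which is exactly what the ROCm question in the review is probing.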
Test Plan:

```
nvcc -O3 -o clamp_bench_comprehensive clamp_benchmark_standalone.cu && ./clamp_bench_comprehensive 2>&1
```
On H100:

```
./clamp_bench_comprehensive
========================================
CUDA Clamp Kernel Comprehensive Benchmark
========================================
Elements: 16777216 (64.0 MB per tensor)

Performance Results:
------------------------------------------------------------
Float32   0% NaN : Orig=  241±  10 ns  Branch=  231±   0 ns  Speedup= +3.9%
Float32   1% NaN : Orig=  238±   5 ns  Branch=  231±   1 ns  Speedup= +3.1%
Float32  10% NaN : Orig=  223±   0 ns  Branch=  215±   1 ns  Speedup= +3.2%
Float16   0% NaN : Orig=  219±   3 ns  Branch=  214±   1 ns  Speedup= +2.2%
Float16   1% NaN : Orig=  225±   2 ns  Branch=  220±   0 ns  Speedup= +2.2%
Float16  10% NaN : Orig=  304±  89 ns  Branch=  217±   4 ns  Speedup=+28.5%
BFloat16  0% NaN : Orig=  216±   1 ns  Branch=  212±   3 ns  Speedup= +1.6%
BFloat16  1% NaN : Orig=  216±   1 ns  Branch=  212±   0 ns  Speedup= +1.9%
BFloat16 10% NaN : Orig=  217±   2 ns  Branch=  211±   0 ns  Speedup= +2.6%
```
On B200:

```
./clamp_bench_comprehensive
========================================
CUDA Clamp Kernel Comprehensive Benchmark
========================================
Elements: 16777216 (64.0 MB per tensor)

Performance Results:
------------------------------------------------------------
Float32   0% NaN : Orig=104331±  53 ns  Branch= 59445±  17 ns  Speedup=+43.0%
Float32   1% NaN : Orig=104520±  13 ns  Branch= 59439±  13 ns  Speedup=+43.1%
Float32  10% NaN : Orig=104493±  17 ns  Branch= 59440±   6 ns  Speedup=+43.1%
Float16   0% NaN : Orig= 98249±  16 ns  Branch= 53278±  16 ns  Speedup=+45.8%
Float16   1% NaN : Orig= 98313±   9 ns  Branch= 53287±  13 ns  Speedup=+45.8%
Float16  10% NaN : Orig= 98335±   9 ns  Branch= 53287±  19 ns  Speedup=+45.8%
BFloat16  0% NaN : Orig= 98492±  47 ns  Branch= 55312±   8 ns  Speedup=+43.8%
BFloat16  1% NaN : Orig= 99783±  69 ns  Branch= 55321±   9 ns  Speedup=+44.6%
BFloat16 10% NaN : Orig=100284±  37 ns  Branch= 55329±  26 ns  Speedup=+44.8%
========================================
```
Differential Revision: D86561069