torch.topk: refactor global histogram/cumsum into a dedicated kernel to eliminate redundant memory access #164459
Closed
YyWangCS wants to merge 5 commits into pytorch:main
Conversation
… avoid redundant memory access
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164459
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit a5dc805 with merge base 39c340e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Contributor
Author
cc @ngimel
ngimel approved these changes Oct 3, 2025
Skylion007 approved these changes Oct 3, 2025
Contributor
Author
@eqy This PR has been approved by two reviewers; could you help merge it?
Collaborator
@pytorchbot merge
Collaborator
@YyWangCS You can use @pytorchbot to self-help.
Collaborator
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request on Oct 21, 2025:
…to eliminate redundant memory access (pytorch#164459)
jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request on Dec 15, 2025:
…to eliminate redundant memory access (pytorch#164459)
# TLDR

This PR removes the regression in torch.topk introduced in torch 2.7.0 and delivers much better performance for large inputs.

The table below reports execution times on H20 for various input sizes with float32 data, extracting the top-100 values. Results indicate that this PR restores and improves performance, especially on large inputs.

| Input Shape    | torch 2.6.0 (ms) | torch 2.8.0 (ms) | 2.8.0 + this PR (ms) |
| -------------- | ---------------- | ---------------- | -------------------- |
| (1, 1B)        | 36.6             | 1564.1           | 25.6                 |
| (1, 100M)      | 3.56             | 17.4             | 2.54                 |
| (1, 1,000,000) | 0.135            | 0.145            | 0.098                |
| (512, 128000)  | 1.33             | 1.33             | 1.32                 |
| (8192, 128000) | 19.6             | 19.6             | 19.4                 |
# Background

After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from **36 ms** to **1.6 s**. Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in PR #145536.
# Analysis

`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages (a sketch of this scheme follows at the end of this section):

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before PR #145536, both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were completed before reduction.

In PR #145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.
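For illustration, here is a minimal sketch of the two-stage scheme described above, simplified to a single slice and an 8-bit digit. The kernel names, buffer layout, and radix width are assumptions made for exposition; they are not the actual PyTorch kernels.

```cpp
// Minimal two-stage histogram sketch (illustrative only; names and layout are
// assumptions, not the PyTorch implementation).
constexpr int RADIX_BITS = 8;
constexpr int RADIX_SIZE = 1 << RADIX_BITS;  // 256 digit bins

// Stage 1: each block histograms part of the input and writes its local
// histogram to global memory (one RADIX_SIZE-wide row per block).
__global__ void localHistogramSketch(const unsigned int* input, long n,
                                     int digitPos, unsigned int* localHist) {
  __shared__ unsigned int smem[RADIX_SIZE];
  for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) smem[i] = 0;
  __syncthreads();

  for (long i = (long)blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += (long)gridDim.x * blockDim.x) {
    unsigned int digit = (input[i] >> digitPos) & (RADIX_SIZE - 1);
    atomicAdd(&smem[digit], 1u);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x)
    localHist[(long)blockIdx.x * RADIX_SIZE + i] = smem[i];
}

// Stage 2: a single block reads every local histogram back from global memory
// and reduces them into the final global histogram. Re-reading this buffer
// once per block per slice is the redundant traffic this PR removes.
__global__ void globalReduceSketch(const unsigned int* localHist,
                                   int numBlocks, unsigned int* globalHist) {
  for (int bin = threadIdx.x; bin < RADIX_SIZE; bin += blockDim.x) {
    unsigned int sum = 0;
    for (int b = 0; b < numBlocks; ++b)
      sum += localHist[(long)b * RADIX_SIZE + bin];
    globalHist[bin] = sum;
  }
}
```

In this sketch, stage 2 would be launched as a single block of at least 256 threads once every stage-1 block has finished; the real kernels also carry per-slice offsets and k-selection state, which are omitted here.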
# This PR

To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up local histogram reads (see the sketch after this list).
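As a rough illustration of this idea, the following is a hedged sketch of what a dedicated histogram-plus-cumsum kernel could look like. The name `digitCumSumSketch`, the 256-bin layout, the unroll factor, and the single-slice-per-block simplification are assumptions for exposition, not the actual `computeDigitCumSum` implementation.

```cpp
// Illustrative one-block-per-slice kernel: sum the local histograms for a
// slice and produce an exclusive prefix sum over the digit bins. This is a
// sketch under assumed layouts, not the actual computeDigitCumSum kernel.
__global__ void digitCumSumSketch(const unsigned int* localHist,
                                  int blocksPerSlice,
                                  unsigned int* digitCumSum) {
  constexpr int RADIX_SIZE = 256;
  __shared__ unsigned int counts[RADIX_SIZE];
  const unsigned int* sliceHist =
      localHist + (long)blockIdx.x * blocksPerSlice * RADIX_SIZE;

  // Each thread owns one digit bin and sums it across the local histograms.
  // The unroll hint issues several independent loads per iteration, which is
  // the "loop unrolling" speedup referred to above.
  for (int bin = threadIdx.x; bin < RADIX_SIZE; bin += blockDim.x) {
    unsigned int sum = 0;
    #pragma unroll 4
    for (int b = 0; b < blocksPerSlice; ++b)
      sum += sliceHist[(long)b * RADIX_SIZE + bin];
    counts[bin] = sum;
  }
  __syncthreads();

  // Serial exclusive prefix sum over the 256 bins, kept on thread 0 for
  // clarity; a real kernel would likely use a parallel block-wide scan.
  if (threadIdx.x == 0) {
    unsigned int running = 0;
    for (int bin = 0; bin < RADIX_SIZE; ++bin) {
      unsigned int c = counts[bin];
      digitCumSum[(long)blockIdx.x * RADIX_SIZE + bin] = running;
      running += c;
    }
  }
}
```

Launched with one block per slice, each row of local histograms is read exactly once, rather than once per `blocks_per_slice` blocks as in the merged kernel.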
# Performance

We benchmarked torch.topk on NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR effectively eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

| Input Shape    | torch 2.6.0 (ms) | torch 2.8.0 (ms) | 2.8.0 + this PR (ms) |
| -------------- | ---------------- | ---------------- | -------------------- |
| (1, 1B)        | 36.6             | 1564.1           | 25.6                 |
| (1, 100M)      | 3.56             | 17.4             | 2.54                 |
| (1, 1,000,000) | 0.135            | 0.145            | 0.098                |
| (512, 128000)  | 1.33             | 1.33             | 1.32                 |
| (8192, 128000) | 19.6             | 19.6             | 19.4                 |
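For reference, a timing harness along these lines can reproduce this kind of measurement. This is a sketch using the libtorch C++ API; the shape, k, warm-up count, and iteration count are assumptions rather than the exact benchmark behind the table.

```cpp
// Hedged benchmark sketch (libtorch C++ API); parameters are illustrative.
#include <torch/torch.h>
#include <chrono>
#include <iostream>

int main() {
  torch::NoGradGuard no_grad;
  auto x = torch::randn({1, 100000000},
                        torch::dtype(torch::kFloat32).device(torch::kCUDA));

  // Warm up so one-time setup costs do not skew the measurement.
  for (int i = 0; i < 3; ++i) torch::topk(x, /*k=*/100, /*dim=*/-1);
  torch::cuda::synchronize();

  constexpr int kIters = 20;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) torch::topk(x, /*k=*/100, /*dim=*/-1);
  torch::cuda::synchronize();
  auto end = std::chrono::steady_clock::now();

  double ms =
      std::chrono::duration<double, std::milli>(end - start).count() / kIters;
  std::cout << "torch.topk average: " << ms << " ms" << std::endl;
  return 0;
}
```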
In addition, I have verified the correctness of this PR with different inputs.
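One way to spot-check correctness is to compare the top-k values against a full sort. The snippet below is a hedged example of such a check, again using the libtorch C++ API; the shapes and k are assumptions, not the exact test cases used.

```cpp
// Hedged correctness spot-check: topk values vs. a full descending sort.
#include <torch/torch.h>

int main() {
  auto x = torch::randn({512, 128000},
                        torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto [vals, idxs] = torch::topk(x, /*k=*/100, /*dim=*/-1,
                                  /*largest=*/true, /*sorted=*/true);

  // Reference: sort each row in descending order and take the first k values.
  auto ref = std::get<0>(torch::sort(x, /*dim=*/-1, /*descending=*/true))
                 .narrow(/*dim=*/1, /*start=*/0, /*length=*/100);
  TORCH_CHECK(torch::equal(vals, ref),
              "topk values differ from sorted reference");

  // The returned indices must gather exactly the returned values.
  TORCH_CHECK(torch::equal(x.gather(/*dim=*/1, idxs), vals),
              "topk indices do not point at the reported values");
  return 0;
}
```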