Skip to content

torch.topk: refactor global histogram/cumsum into a dedicated kernel to eliminate redundant memory access#164459

Closed
YyWangCS wants to merge 5 commits intopytorch:mainfrom
YyWangCS:YyWangCS/optimize_topk
Closed

torch.topk: refactor global histogram/cumsum into a dedicated kernel to eliminate redundant memory access#164459
YyWangCS wants to merge 5 commits intopytorch:mainfrom
YyWangCS:YyWangCS/optimize_topk

Conversation

@YyWangCS
Copy link
Contributor

@YyWangCS YyWangCS commented Oct 2, 2025

TLDR

This PR removes the regression in torch.topk introduced from torch 2.7.0 and delivers much better performance for large inputs.

The table below reports execution times on H20 for various input sizes with float32 data, extracting the top-100 values. Results indicate that this PR restores and improves performance, especially on large inputs.

Input Shape torch2.6.0 (ms) torch2.8.0 (ms) 2.8.0+this PR (ms)
(1, 1B) 36.6 1564.1 25.6
(1, 100M) 3.56 17.4 2.54
(1, 1000,000) 0.135 0.145 0.098
(512, 128000) 1.33 1.33 1.32
(8192, 128000) 19.6 19.6 19.4

Background

After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in torch.topk on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from 36 ms to 1.6 s.

Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in PR #145536.

Analysis

torch.topk relies on RadixSelect to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

  1. Local histogram: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
  2. Global reduction: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before PR #145536, both stages ran inside a single kernel (radixFindKthValues), using a semaphore to ensure that all local histograms were completed before reduction.

In PR #145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (computeBlockwiseKthCounts) to avoid the semaphore. While this simplifies synchronization, it introduces redundant memory reads:

  • computeBlockwiseKthCounts launches numInputSlices * blocks_per_slice blocks.
  • For each row (slice), blocks_per_slice CUDA blocks redundantly reload the same local histograms from global memory.

This PR

To address this inefficiency, we introduce the following optimizations:

  1. Dedicated kernel: Refactor global histogram and cumsum computation into a separate GPU kernel, computeDigitCumSum.
  2. Loop unrolling: Apply loop unrolling in computeDigitCumSum to speed up local histogram reads.

Performance

We benchmarked torch.topk on NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR effectively eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

Input Shape torch2.6.0 (ms) torch2.8.0 (ms) 2.8.0+this PR (ms)
(1, 1B) 36.6 1564.1 25.6
(1, 100M) 3.56 17.4 2.54
(1, 1000,000) 0.135 0.145 0.098
(512, 128000) 1.33 1.33 1.32
(8192, 128000) 19.6 19.6 19.4

Besides, I have verified the correctness of this PR with different inputs.

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164459

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a5dc805 with merge base 39c340e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Oct 2, 2025
@eqy eqy requested a review from zasdfgbnm October 2, 2025 14:58
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Oct 2, 2025
@YyWangCS
Copy link
Contributor Author

YyWangCS commented Oct 3, 2025

cc @ngimel

@YyWangCS
Copy link
Contributor Author

YyWangCS commented Oct 7, 2025

@eqy This PR is approved by two reviewers and could you help merge it?

@cyyever
Copy link
Collaborator

cyyever commented Oct 7, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 7, 2025
@cyyever
Copy link
Collaborator

cyyever commented Oct 7, 2025

@YyWangCS You can use @pytorchbot to self-help.

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…to eliminate redundant memory access (pytorch#164459)

# TLDR
This PR removes the regression in torch.topk introduced from torch 2.7.0 and delivers much better performance for large inputs.

The table below reports execution times on H20 for various input sizes with float32 data, extracting the top-100 values. Results indicate that this PR restores and improves performance, especially on large inputs.
| Input Shape    | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B)        | 36.6            | 1564.1          | 25.6               |
| (1, 100M)      | 3.56            | 17.4            | 2.54               |
| (1, 1000,000)  | 0.135           | 0.145           | 0.098              |
| (512, 128000)  | 1.33            | 1.33            | 1.32               |
| (8192, 128000) | 19.6            | 19.6            | 19.4               |

# Background
After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from **36 ms** to **1.6 s**.

Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR pytorch#145536](pytorch#145536).

# Analysis

`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before [PR pytorch#145536](pytorch#145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were completed before reduction.

In PR pytorch#145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.

# This PR

To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up local histogram reads.

# Performance
We benchmarked torch.topk on NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR effectively eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

| Input Shape    | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B)        | 36.6            | 1564.1          | 25.6               |
| (1, 100M)      | 3.56            | 17.4            | 2.54               |
| (1, 1000,000)  | 0.135           | 0.145           | 0.098              |
| (512, 128000)  | 1.33            | 1.33            | 1.32               |
| (8192, 128000) | 19.6            | 19.6            | 19.4               |

Besides, I have verified the correctness of this PR with different inputs.
Pull Request resolved: pytorch#164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007
jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Dec 15, 2025
…to eliminate redundant memory access (pytorch#164459)

# TLDR
This PR removes the regression in torch.topk introduced from torch 2.7.0 and delivers much better performance for large inputs.

The table below reports execution times on H20 for various input sizes with float32 data, extracting the top-100 values. Results indicate that this PR restores and improves performance, especially on large inputs.
| Input Shape    | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B)        | 36.6            | 1564.1          | 25.6               |
| (1, 100M)      | 3.56            | 17.4            | 2.54               |
| (1, 1000,000)  | 0.135           | 0.145           | 0.098              |
| (512, 128000)  | 1.33            | 1.33            | 1.32               |
| (8192, 128000) | 19.6            | 19.6            | 19.4               |

# Background
After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from **36 ms** to **1.6 s**.

Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR pytorch#145536](pytorch#145536).

# Analysis

`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before [PR pytorch#145536](pytorch#145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were completed before reduction.

In PR pytorch#145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.

# This PR

To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up local histogram reads.

# Performance
We benchmarked torch.topk on NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR effectively eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

| Input Shape    | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B)        | 36.6            | 1564.1          | 25.6               |
| (1, 100M)      | 3.56            | 17.4            | 2.54               |
| (1, 1000,000)  | 0.135           | 0.145           | 0.098              |
| (512, 128000)  | 1.33            | 1.33            | 1.32               |
| (8192, 128000) | 19.6            | 19.6            | 19.4               |

Besides, I have verified the correctness of this PR with different inputs.
Pull Request resolved: pytorch#164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007
jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Dec 15, 2025
…to eliminate redundant memory access (pytorch#164459)

# TLDR
This PR removes the regression in torch.topk introduced from torch 2.7.0 and delivers much better performance for large inputs.

The table below reports execution times on H20 for various input sizes with float32 data, extracting the top-100 values. Results indicate that this PR restores and improves performance, especially on large inputs.
| Input Shape    | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B)        | 36.6            | 1564.1          | 25.6               |
| (1, 100M)      | 3.56            | 17.4            | 2.54               |
| (1, 1000,000)  | 0.135           | 0.145           | 0.098              |
| (512, 128000)  | 1.33            | 1.33            | 1.32               |
| (8192, 128000) | 19.6            | 19.6            | 19.4               |

# Background
After upgrading PyTorch from 2.6.0 to 2.7.0, we observed a significant GPU performance regression in `torch.topk` on NVIDIA GPUs. For instance, extracting the top-1000 largest values from one billion floats on an NVIDIA H20 increased from **36 ms** to **1.6 s**.

Profiling with Nsight Compute indicates that the slowdown is caused by redundant memory accesses introduced in [PR pytorch#145536](pytorch#145536).

# Analysis

`torch.topk` relies on **RadixSelect** to find the target values. Each radix pass requires computing a histogram of the input values. For large inputs, histogram computation is split into two stages:

1. **Local histogram**: Each CUDA block processes a subset of the input and writes its local histogram to global memory.
2. **Global reduction**: A single CUDA block reads all local histograms from global memory and reduces them into the final global histogram.

Before [PR pytorch#145536](pytorch#145536), both stages ran inside a single kernel (`radixFindKthValues`), using a semaphore to ensure that all local histograms were completed before reduction.

In PR pytorch#145536, the global histogram computation was merged with subsequent top-k calculations into a single kernel (`computeBlockwiseKthCounts`) to avoid the semaphore. While this simplifies synchronization, it introduces **redundant memory reads**:

- `computeBlockwiseKthCounts` launches `numInputSlices * blocks_per_slice` blocks.
- For each row (slice), `blocks_per_slice` CUDA blocks redundantly reload the same local histograms from global memory.

# This PR

To address this inefficiency, we introduce the following optimizations:

1. **Dedicated kernel**: Refactor global histogram and cumsum computation into a separate GPU kernel, `computeDigitCumSum`.
2. **Loop unrolling**: Apply loop unrolling in `computeDigitCumSum` to speed up local histogram reads.

# Performance
We benchmarked torch.topk on NVIDIA H20 with float32 inputs, extracting the top-100 values across different input sizes. The results in the table below demonstrate that this PR effectively eliminates the performance regression introduced in 2.7.0 and delivers substantial improvements on large inputs.

| Input Shape    | torch2.6.0 (ms) | torch2.8.0 (ms) | 2.8.0+this PR (ms) |
| -------------- | --------------- | --------------- | ------------------ |
| (1, 1B)        | 36.6            | 1564.1          | 25.6               |
| (1, 100M)      | 3.56            | 17.4            | 2.54               |
| (1, 1000,000)  | 0.135           | 0.145           | 0.098              |
| (512, 128000)  | 1.33            | 1.33            | 1.32               |
| (8192, 128000) | 19.6            | 19.6            | 19.4               |

Besides, I have verified the correctness of this PR with different inputs.
Pull Request resolved: pytorch#164459
Approved by: https://github.com/ngimel, https://github.com/Skylion007
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: cuda release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants