Skip to content

Add env DeviceTopK without temp_storage args#8982

Open
gonidelis wants to merge 7 commits into
NVIDIA:mainfrom
gonidelis:env_topk
Open

Add env DeviceTopK without temp_storage args#8982
gonidelis wants to merge 7 commits into
NVIDIA:mainfrom
gonidelis:env_topk

Conversation

@gonidelis
Copy link
Copy Markdown
Member

fixes #8969

@copy-pr-bot
Copy link
Copy Markdown
Contributor

copy-pr-bot Bot commented May 14, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Progress in CCCL May 14, 2026
@gonidelis gonidelis marked this pull request as ready for review May 14, 2026 05:56
@gonidelis gonidelis requested a review from a team as a code owner May 14, 2026 05:56
@cccl-authenticator-app cccl-authenticator-app Bot moved this from In Progress to In Review in CCCL May 14, 2026
@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 14, 2026

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b28511d4-6cda-4542-8e9d-4bb0ca22f9a8

📥 Commits

Reviewing files that changed from the base of the PR and between 776655d and b0a8519.

📒 Files selected for processing (1)
  • cub/cub/device/device_topk.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • cub/cub/device/device_topk.cuh

📝 Walkthrough

Summary by CodeRabbit

  • New Features

    • Added environment-based DeviceTopK overloads for Max/Min Keys and Pairs that accept an execution environment (stream, determinism, ordering) and manage temporary storage internally.
  • Tests

    • Added unit tests validating env-based TopK for built-in and custom key types (with decomposer), including stream-based execution, host-side sorting verification, and edge cases (e.g., K == N).

important:

Walkthrough

Added environment-dispatched DeviceTopK overloads (Max/Min Keys and Pairs, decomposer and non-decomposer variants) that allocate/manage temporary storage via the execution environment; added Catch2 tests validating these env-based overloads with built-in and custom types.

Changes

Environment-Based TopK API

Layer / File(s) Summary
Environment dispatch integration
cub/cub/device/device_topk.cuh
Added #include for cub/detail/env_dispatch.cuh to enable env-dispatched temporary-storage allocation.
MaxPairs environment overloads
cub/cub/device/device_topk.cuh
Added two env-based DeviceTopK::MaxPairs overloads (non-decomposer and decomposer, marked [[nodiscard]]) that allocate temp storage via detail::dispatch_with_env and call detail::dispatch_topk with topk::select::max (use identity_decomposer_t{} for non-decomposer).
MinPairs environment overloads
cub/cub/device/device_topk.cuh
Added two env-based DeviceTopK::MinPairs overloads (non-decomposer and decomposer, marked [[nodiscard]]) that allocate temp storage via detail::dispatch_with_env and call detail::dispatch_topk with topk::select::min.
MaxKeys environment overloads
cub/cub/device/device_topk.cuh
Added two env-based DeviceTopK::MaxKeys overloads (non-decomposer and decomposer, marked [[nodiscard]]) that omit explicit temp-storage parameters, use detail::dispatch_with_env, and dispatch to detail::dispatch_topk with topk::select::max (values passed as NullType*).
MinKeys environment overloads
cub/cub/device/device_topk.cuh
Added two env-based DeviceTopK::MinKeys overloads (non-decomposer and decomposer, marked [[nodiscard]]) that omit explicit temp-storage parameters, use detail::dispatch_with_env, and dispatch to detail::dispatch_topk with topk::select::min.
Environment-based API test coverage
cub/test/catch2_test_device_topk_env_api.cu
New test TU defining topk_custom_t and topk_custom_decomposer_t, plus six Catch2 tests exercising Max/Min Keys and Pairs using an env containing cuda::stream_ref for built-in and decomposer inputs; synchronizes stream and validates results (sorting where output order is unspecified).
Existing test suite expansion
cub/test/catch2_test_device_topk_env.cu
Added thrust/host_vector and <algorithm>, a sorted_top_k helper, multiple env-based tests comparing device outputs to host-sorted expectations for various key types, and an edge-case test where K == num_items.

Assessment against linked issues

Objective Addressed Explanation
Add env-overloads for MaxPairs, MinPairs, MaxKeys, MinKeys that handle memory allocation using memory resource from environment [#8969]
Eliminate requirement for caller-managed d_temp_storage and temp_storage_bytes parameters in environment-based API [#8969]
Provide test coverage for new environment-based overloads with standard and custom types [#8969]

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
cub/test/catch2_test_device_topk_env_api.cu (1)

4-4: 💤 Low value

suggestion: The include insert_nested_NVTX_range_guard.h is not used in this file. Consider removing it.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 0a4f69d0-88d3-41d1-bc72-f0a079b4fc3e

📥 Commits

Reviewing files that changed from the base of the PR and between 5d79bc2 and 1cd8c3d.

📒 Files selected for processing (3)
  • cub/cub/device/device_topk.cuh
  • cub/test/catch2_test_device_topk_api.cu
  • cub/test/catch2_test_device_topk_env_api.cu

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 1c53e6de-7677-45e1-9484-0d81da4be3ae

📥 Commits

Reviewing files that changed from the base of the PR and between 1cd8c3d and 1637e52.

📒 Files selected for processing (2)
  • cub/cub/device/device_topk.cuh
  • cub/test/catch2_test_device_topk_env.cu
🚧 Files skipped from review as they are similar to previous changes (1)
  • cub/cub/device/device_topk.cuh

Comment thread cub/test/catch2_test_device_topk_env.cu
@github-actions

This comment has been minimized.

@github-actions
Copy link
Copy Markdown
Contributor

🥳 CI Workflow Results

🟩 Finished in 3h 05m: Pass: 100%/283 | Total: 3d 09h | Max: 47m 02s | Hits: 90%/214854

See results here.

Comment on lines +638 to +640
[[nodiscard]] CUB_RUNTIME_FUNCTION static //
::cuda::std::enable_if_t<detail::radix::is_valid_decomposer<detail::it_value_t<KeyInputIteratorT>, DecomposerT>,
cudaError_t>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Please put requirements (enable_if) into the template head instead of using the return type.

Comment on lines +1117 to +1119
[[nodiscard]] CUB_RUNTIME_FUNCTION static //
::cuda::std::enable_if_t<detail::radix::is_valid_decomposer<detail::it_value_t<KeyInputIteratorT>, DecomposerT>,
cudaError_t>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Comment on lines +216 to +221
thrust::host_vector<T> h_in(num_items);
for (int i = 0; i < num_items; ++i)
{
h_in[i] = static_cast<T>((i * 1664525 + 1013904223) % 251);
}
thrust::device_vector<T> d_in = h_in;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important: in ordinary unit tests, please use c2h::host_vector etc.

Comment on lines +209 to +211
using topk_element_types = c2h::type_list<int8_t, int16_t, int32_t, uint32_t, int64_t, float, double>;

C2H_TEST("DeviceTopK::MaxKeys env-alloc returns correct top K", "[topk][env]", topk_element_types)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Why do we need to cover different element types when just testing whether the env-overload works correctly? We should not test the topk implementation itself, those tests already exist elsewhere. The current state adds a lot of compile-time for little gain IMO.

Comment on lines +58 to +63
auto error = cub::DeviceTopK::MaxKeys(d_in.begin(), d_out.begin(), static_cast<int>(d_in.size()), k, env);
if (error != cudaSuccess)
{
std::cerr << "cub::DeviceTopK::MaxKeys failed with status: " << error << '\n';
}
// example-end topk-max-keys-env
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important: we should list at least one possible expected outcome in the documentation. Like:

Suggested change
auto error = cub::DeviceTopK::MaxKeys(d_in.begin(), d_out.begin(), static_cast<int>(d_in.size()), k, env);
if (error != cudaSuccess)
{
std::cerr << "cub::DeviceTopK::MaxKeys failed with status: " << error << '\n';
}
// example-end topk-max-keys-env
auto error = cub::DeviceTopK::MaxKeys(d_in.begin(), d_out.begin(), static_cast<int>(d_in.size()), k, env);
if (error != cudaSuccess)
{
std::cerr << "cub::DeviceTopK::MaxKeys failed with status: " << error << '\n';
}
thrust::device_vector<int> expected{9, 8, 7}; // possibly in different order
// example-end topk-max-keys-env

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: In Review

Development

Successfully merging this pull request may close these issues.

cub::DeviceTopK does not have env-overloads handling memory allocation

2 participants