
Update/add to qr_ks_vs_whole_k_prefetch pipeline #3485

Open

qianfengz wants to merge 25 commits into develop from whole_k_prefetch_n0loop
Conversation


@qianfengz qianfengz commented Dec 24, 2025

About qr_ks_vs_whole_k_prefetch pipeline

The qr_ks_vs_whole_k_prefetch pipeline is mainly intended for situations where the total number of work-groups is not enough to occupy the CUs. When the work-group count is low, using an MTile size (kM0) of 64 rather than 128 can improve CU occupancy. With kM0=64, fewer registers are consumed to hold P and O, so enough vgprs are left to prefetch the whole k_tile of the next iteration in the main loop, which improves performance compared to the usual approach of using kM0=128.
Besides prefetching the whole k tile when kM0=64, the pipeline also has a kM0=128 path, in which 1/2 of the n0_loop slices of the k tile are prefetched for the next iteration. The kM0=128 path can be used as a replacement for the qr_ks_vs_async pipeline.

What this PR does

  1. Update the pipeline policy to ensure the best MFMA instructions are used on MI350
  2. Add the qr_ks_vs_whole_k_prefetch_trload pipeline instance so that V can be loaded using transposed loading on MI350 (avoiding the need for many shuffle instructions)
  3. Use an n0_loop to implement Gemm0 instead of the commonly used k0_loop. The n0_loop brings the benefits of fewer move_tile_window() calls and removes the need for clear_tile(s_acc) in the main loop
  4. Complete support for naive tile loading for hdim96 and hdim160, i.e. loading hdim96/hdim160 tiles without having to pad them to hdim128/hdim256
  5. Other fine-grained improvements (e.g. using an explicit partition_index to guarantee warp_id is allocated in a vgpr for store_tile/load_tile to/from an LDS tile_window)

Performance results

  1. For attention shapes which lead to kM0=64, qr_ks_vs_async_whole_k_prefetch_trload shows much better performance than qr_ks_vs_async_trload on the same case (execution time 41.02ms with whole_k_prefetch_trload vs 58.50ms with async_load)
  2. For attention shapes which lead to kM0=128, qr_ks_vs_async_whole_k_prefetch_trload shows slightly better performance than qr_ks_vs_async on MI350 (execution time 104.50ms with whole_k_prefetch_trload vs 106.50ms with qr_ks_vs_async), and the two show completely on-par performance on MI300

Test/Verify

  1. Use the ROCm xformers branch test_whole_k_prefetch_n0loop to test/verify the qr_ks_vs_whole_k_prefetch pipeline, since this pipeline cannot be used by the ck_tile fmha example so far
  2. Use the following command line for building/testing xformers
#> git clone -b test_whole_k_prefetch_n0loop https://github.com/ROCm/xformers
#> cd xformers
#> git submodule update --init --recursive
#> pip install --no-build-isolation -e ./
#> pytest tests/test_mem_eff_attention.py::test_forward
  3. Any script that can run on xformers can be used to evaluate the qr_ks_vs_whole_k_prefetch pipeline. Use the following two environment variables to switch between pipelines
#> export FMHA_DISABLE_SPECIAL_TREATMENT=1   # disable the FAV3 and qr_ks_vs_async_trload pipelines
#> export FMHA_DISABLE_ASYNC_PIPELINE=1      # disable the qr_ks_vs_async pipeline

Discussion

Copilot AI left a comment
Pull request overview

This PR updates and enhances the qr_ks_vs_whole_k_prefetch pipeline to improve performance on MI350 GPUs through better MFMA instruction usage, transposed V-loading support, and N0-loop implementation. The pipeline targets scenarios where work-group counts are low, enabling better CU occupancy by using smaller MTile sizes (kM0=64 vs 128) while prefetching entire K tiles.

Changes:

  • Adds transposed V-loading support (qr_ks_vs_whole_k_prefetch_trload) to reduce shuffle instructions on MI350
  • Implements N0-loop based Gemm0 to reduce tile window movement overhead and eliminate clear_tile calls
  • Adds full support for hdim96/hdim160 without padding requirements
  • Updates MFMA instruction selection to ensure optimal choices for MI350

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp | New GEMM block implementation supporting transposed V-loading with N-dimension prefetching |
| block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp | N-dimension prefetching GEMM implementation for standard (non-transposed) loading |
| block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp | K-dimension prefetching GEMM implementation |
| tile_fmha_shape.hpp | Adds kN0Sub field and relaxes static assertion for N0-loop support |
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp | New pipeline variant with transposed V-loading |
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp | Comprehensive policy updates for LDS management, alignment, and MFMA selection |
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | Core pipeline updated with N0-loop implementation and simplified memory management |
| block_fmha_pipeline_problem.hpp | Adds utility functions for calculating optimal vector sizes |
| fmha_fwd_kernel.hpp | Kernel updates to support N0-loop pipelines and naive hdim loading |
| fmha.hpp | Includes new trload pipeline header |


load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
};

__builtin_amdgcn_sched_barrier(0x0000001);

Copilot AI Jan 15, 2026


The scheduling barrier mask has an incorrect value. Line 140 uses 0x0000001 (7 digits) while line 127 correctly uses 0x00000001 (8 digits). This should be 0x00000001 to match the proper 32-bit mask format.

Suggested change
__builtin_amdgcn_sched_barrier(0x0000001);
__builtin_amdgcn_sched_barrier(0x00000001);

}

const auto bias_tile = load_tile(bias_dram_window); // load bias tile
__builtin_amdgcn_sched_barrier(0x000000001);

Copilot AI Jan 15, 2026


The scheduling barrier mask has an incorrect value. This uses 0x000000001 (9 digits) when it should be 0x00000001 (8 digits) to match the proper 32-bit hexadecimal format.

Suggested change
__builtin_amdgcn_sched_barrier(0x000000001);
__builtin_amdgcn_sched_barrier(0x00000001);

static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0

Copilot AI Jan 15, 2026


The variable kN0Sub is assigned from BlockTile index 2, which is the same index as kK0 (line 52). This appears to be intentional based on the assertion at line 59, but the naming is confusing since kN0Sub suggests it's related to kN0, not kK0. Consider renaming to better reflect its relationship to both dimensions, or add a clarifying comment.

Suggested change
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // same index as kK0; used as subdivision factor when dividing kN0

{
if(num_total_loop <= 0)
// assuming no random values need be saved, this is true when the pipeline is called from
// xformers, since we have a separate kernel to generated randomm values

Copilot AI Jan 15, 2026


Corrected spelling of 'generated' from 'generated randomm values' to 'generate random values'.

Suggested change
// xformers, since we have a separate kernel to generated randomm values
// xformers, since we have a separate kernel to generate random values

template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;

// A helper struct for detechting kUseTrLoad

Copilot AI Jan 15, 2026


Corrected spelling of 'detechting' to 'detecting'.

Suggested change
// A helper struct for detechting kUseTrLoad
// A helper struct for detecting kUseTrLoad

Comment on lines +40 to +41
else
static_assert(false, "The data type is not supported!");

Copilot AI Jan 15, 2026


Using static_assert(false, ...) directly can cause compilation issues with some compilers even when the branch is not taken. Consider using a type-dependent false condition like static_assert(sizeof(DataType) == 0, ...) or static_assert(!std::is_same_v<DataType, DataType>, ...).

@asleepzzz
Contributor

we found async can beat whole_k_prefetch with a new config, will discuss with qianfeng

@ammallya
Contributor

ammallya commented Feb 3, 2026

Error importing due to merge conflicts – please reopen the PR on ROCm/rocm-libraries
