Update/add to qr_ks_vs_whole_k_prefetch pipeline #3485
Conversation
…oping Gemm0 along n0 dimension
…e_k_prefetch pipeline
…n whole_k_prefetch path)
…n whole_k_prefetch path in trload pipeline)
… next iteration in the non-whole-k-prefetch path
Pull request overview
This PR updates and enhances the qr_ks_vs_whole_k_prefetch pipeline to improve performance on MI350 GPUs through better MFMA instruction usage, transposed V-loading support, and N0-loop implementation. The pipeline targets scenarios where work-group counts are low, enabling better CU occupancy by using smaller MTile sizes (kM0=64 vs 128) while prefetching entire K tiles.
Changes:
- Adds transposed V-loading support (`qr_ks_vs_whole_k_prefetch_trload`) to reduce shuffle instructions on MI350
- Implements N0-loop based Gemm0 to reduce tile window movement overhead and eliminate `clear_tile` calls
- Adds full support for hdim96/hdim160 without padding requirements
- Updates MFMA instruction selection to ensure optimal choices for MI350
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp | New GEMM block implementation supporting transposed V-loading with N-dimension prefetching |
| block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp | N-dimension prefetching GEMM implementation for standard (non-transposed) loading |
| block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp | K-dimension prefetching GEMM implementation |
| tile_fmha_shape.hpp | Adds kN0Sub field and relaxes static assertion for N0-loop support |
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp | New pipeline variant with transposed V-loading |
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp | Comprehensive policy updates for LDS management, alignment, and MFMA selection |
| block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | Core pipeline updated with N0-loop implementation and simplified memory management |
| block_fmha_pipeline_problem.hpp | Adds utility functions for calculating optimal vector sizes |
| fmha_fwd_kernel.hpp | Kernel updates to support N0-loop pipelines and naive hdim loading |
| fmha.hpp | Includes new trload pipeline header |
```cpp
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
};

__builtin_amdgcn_sched_barrier(0x0000001);
```

The scheduling barrier mask is formatted inconsistently. Line 140 uses `0x0000001` (7 digits) while line 127 uses `0x00000001` (8 digits). Both literals evaluate to the same value, but this should be written as `0x00000001` to match the 32-bit mask format used elsewhere.

```suggestion
__builtin_amdgcn_sched_barrier(0x00000001);
```
```cpp
}

const auto bias_tile = load_tile(bias_dram_window); // load bias tile
__builtin_amdgcn_sched_barrier(0x000000001);
```

The scheduling barrier mask uses `0x000000001` (9 digits) where the proper 32-bit hexadecimal format is `0x00000001` (8 digits). The extra leading zero does not change the value, but the width should be consistent.

```suggestion
__builtin_amdgcn_sched_barrier(0x00000001);
```
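For context on what the argument encodes (summarized from LLVM's `llvm.amdgcn.sched.barrier` documentation, not from this PR): the value is a bitmask selecting which instruction classes the compiler scheduler may still move across the barrier, so the number of leading zeros is purely cosmetic and both fixes above are formatting cleanups. A minimal sketch of the two masks relevant here:

```cpp
// Hedged sketch based on LLVM's llvm.amdgcn.sched.barrier docs; verify the
// mask bits against the LLVM version in use. HIP device code.
__device__ void sched_barrier_examples()
{
    __builtin_amdgcn_sched_barrier(0x00000000); // nothing may be scheduled
                                                // across the barrier
    __builtin_amdgcn_sched_barrier(0x00000001); // non-memory, non-side-effect
                                                // instructions may cross
}
```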
```cpp
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0
```

The variable `kN0Sub` is assigned from `BlockTile` index 2, which is the same index as `kK0` (line 52). This appears to be intentional based on the assertion at line 59, but the naming is confusing since `kN0Sub` suggests it is related to `kN0`, not `kK0`. Consider renaming it to better reflect its relationship to both dimensions, or adding a clarifying comment.

```suggestion
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // same index as kK0; used as the subdivision factor when dividing kN0
```
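For illustration, a sketch of how such a subdivision factor is typically consumed; the `n0_loops` name comes from the PR description, but this derivation is an assumption rather than the file's actual code:

```cpp
// Hypothetical continuation of the declarations above: slice the kN0 tile
// into the iterations that the N0-loop Gemm0 walks through.
static constexpr index_t n0_loops = kN0 / kN0Sub;
static_assert(kN0 % kN0Sub == 0, "kN0Sub must evenly divide kN0");
```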
```cpp
{
    if(num_total_loop <= 0)
// assuming no random values need be saved, this is true when the pipeline is called from
// xformers, since we have a separate kernel to generated randomm values
```

Spelling: 'generated randomm values' should be 'generate random values'.

```suggestion
// xformers, since we have a separate kernel to generate random values
```
```cpp
template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;

// A helper struct for detechting kUseTrLoad
```

Spelling: 'detechting' should be 'detecting'.

```suggestion
// A helper struct for detecting kUseTrLoad
```
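For readers unfamiliar with such helpers, here is a generic sketch of the detection idiom they typically follow; the trait name and the assumption that `kUseTrLoad` is a `static constexpr bool` member are illustrative, not the PR's exact definitions:

```cpp
#include <type_traits>

// Primary template: T has no kUseTrLoad member, so default to false.
template <typename T, typename = void>
struct has_tr_load_flag : std::false_type {};

// Partial specialization selected via SFINAE when T::kUseTrLoad is well-formed.
template <typename T>
struct has_tr_load_flag<T, std::void_t<decltype(T::kUseTrLoad)>>
    : std::bool_constant<T::kUseTrLoad> {};

template <typename T>
inline constexpr bool is_tr_load_v = has_tr_load_flag<T>::value;
```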
```cpp
else
    static_assert(false, "The data type is not supported!");
```

Using `static_assert(false, ...)` directly can cause compilation issues with some compilers even when the branch is not taken. Consider using a type-dependent false condition like `static_assert(sizeof(DataType) == 0, ...)` or `static_assert(!std::is_same_v<DataType, DataType>, ...)`.
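As a concrete illustration of the suggested workaround, here is a minimal sketch of the dependent-false idiom; `dependent_false` and `Run` are made-up names, not code from this PR:

```cpp
#include <type_traits>

// Always false, but dependent on T, so the assertion is only evaluated
// when this template (and hence the else branch) is actually instantiated.
template <typename>
inline constexpr bool dependent_false = false;

template <typename DataType>
void Run()
{
    if constexpr(std::is_same_v<DataType, float>)
    {
        // supported path
    }
    else
    {
        static_assert(dependent_false<DataType>, "The data type is not supported!");
    }
}
```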
We found async can beat whole-k with a new config; will discuss with qianfeng.
Error importing due to merge conflicts – please reopen the PR on ROCm/rocm-libraries |
About qr_ks_vs_whole_k_prefetch pipeline
The pipeline `qr_ks_vs_whole_k_prefetch` is mainly used in situations where the total number of work-groups is not enough to occupy the CUs. When the total number of work-groups is low, using an MTile size (`kM0`) of 64 rather than 128 can improve CU occupancy. And with `kM0=64`, fewer registers are consumed to save P and O, so enough VGPRs are left to prefetch the whole k_tile of the next iteration in the main loop; performance can thus improve compared to the usual method of using `kM0=128`.
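To make the register argument concrete, a back-of-the-envelope sketch; the 4-wave by 64-lane work-group and fp32 accumulators are assumptions for illustration, and real counts depend on the tile distribution:

```cpp
// Per-thread fp32 registers needed to hold an accumulator tile, assuming
// a work-group of 4 waves x 64 lanes (illustrative numbers only).
constexpr int lanes_per_wg = 4 * 64;

constexpr int acc_vgprs(int rows, int cols) { return rows * cols / lanes_per_wg; }

static_assert(acc_vgprs(128, 128) == 64); // P tile with kM0 = 128, kN0 = 128
static_assert(acc_vgprs(64, 128) == 32);  // halved with kM0 = 64, freeing
                                          // VGPRs to prefetch the next k_tile
```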
Besides prefetching the whole k tile when `kM0=64`, the pipeline also has a path that uses `kM0=128`, in which case 1/2 of `n0_loops` slices of the k tile are prefetched for the next iteration. The `kM0=128` path can be used as a replacement for the pipeline `qr_ks_vs_async`.

What this PR does
- Implements Gemm0 by looping along the n0 dimension; the `n0_loop` brings the benefits of fewer `move_tile_window()` calls and removes the need for `clear_tile(s_acc)` in the main loop.
- Uses `partition_index` to guarantee `warp_id` is allocated on a VGPR for `store_tile`/`load_tile` to/from the LDS tile_window.

Performance results
- `qr_ks_vs_async_whole_k_prefetch_trload` shows much better performance than `qr_ks_vs_async_trload` on the same case (execution time 41.02 ms by whole_k_prefetch_trload vs. 58.50 ms by async_load).
- `qr_ks_vs_async_whole_k_prefetch_trload` shows slightly better performance than `qr_ks_vs_async` on MI350 (execution time 104.50 ms by whole_k_prefetch_trload vs. 106.50 ms by qr_ks_vs_async), and they show completely on-par performance on MI300.

Test/Verify
- Adds `test_whole_k_prefetch_n0loop` to test/verify the qr_ks_vs_whole_k_prefetch pipeline, since this pipeline cannot be used by the ck_tile fmha example so far.

Discussion