Commit 6be8295

jerrymannil authored and pytorchmergebot committed

[ROCm] Improve vectorized elementwise kernel performance in MI300X (#153634)

* Use non-temporal loads to improve the vectorized elementwise kernel performance on MI300
* Use a thread_work_size of 8 or 16 for the vectorized elementwise kernel

Co-author: @amd-hhashemi
Pull Request resolved: #153634
Approved by: https://github.com/jeffdaily

1 parent 555fc05 · commit 6be8295
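For readers unfamiliar with the first optimization: a non-temporal load tells the compiler the loaded value will not be reused soon, so the cache hierarchy can avoid retaining it. Elementwise kernels read each input element exactly once, which is the streaming pattern this hint targets. A minimal standalone sketch, not from this commit (kernel and buffer names are illustrative), using the same Clang builtin the diff below relies on:

#include <hip/hip_runtime.h>

// Streaming elementwise kernel: each input element is read exactly once,
// so keeping it in cache is wasted effort. __builtin_nontemporal_load
// hints this to the compiler/hardware.
__global__ void scale_streaming(const float* __restrict__ in,
                                float* __restrict__ out,
                                int n, float alpha) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = __builtin_nontemporal_load(&in[i]);  // cache-bypass hint
    out[i] = alpha * v;
  }
}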

File tree: 2 files changed, +42 −6 lines


aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 18 additions & 6 deletions
@@ -226,9 +226,15 @@ C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
   using traits = function_traits<func_t>;
   constexpr auto io_size = calc_io_size<func_t>();
-  int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+#ifdef __gfx942__
+  constexpr int tws = (io_size >= 2) ? 8 : 16;
+#else
+  constexpr int tws = elems_per_thread<io_size>();
+#endif
+  constexpr int bws = tws * num_threads();
+  int remaining = N - bws * blockIdx.x;
 
-  if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+  if (remaining < bws) { // if this block handles the remainder,
     // just do a naive unrolled loop
     auto input_calc = TrivialOffsetCalculator<traits::arity>();
     auto output_calc = TrivialOffsetCalculator<1>();

@@ -240,14 +246,14 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
         decltype(output_calc),
         memory::LoadWithoutCast,
         memory::StoreWithoutCast,
-        elems_per_thread<io_size>()>(
+        tws>(
         data, remaining, input_calc, output_calc, loader, storer);
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
     // vectorized memory access
     constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
     elementwise_kernel_helper(
-        f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
+        f, memory::policies::vectorized<optimal_vec_size, array_t, tws>(data));
   }
 }
 #endif // USE_ROCM

@@ -285,10 +291,12 @@ static inline void launch_vectorized_kernel(
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
   constexpr auto io_size = calc_io_size<func_t>();
-  int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
 #ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
+  c10::DeviceIndex curDevice = -1;
+  AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
+  int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread<io_size>();
 #else
   using cpp_type = typename function_traits<func_t>::result_type;
   const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);

@@ -305,7 +313,10 @@ static inline void launch_vectorized_kernel(
     if constexpr (sizeof(cpp_type) < 2) {
       vec_size = std::min<uint16_t>(vec_size, 4);
     }
+  int tws = elems_per_thread<io_size>();
 #endif
+  int bws = tws * num_threads();
+  int64_t grid = (N + bws - 1) / bws;
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:

@@ -334,8 +345,9 @@ static inline void launch_vectorized_kernel(
       auto output_calc = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithoutCast();
       auto storer = memory::StoreWithoutCast();
+      int64_t grid_unrolled = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
       unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
-          <<<grid, num_threads(), 0, stream>>>(
+          <<<grid_unrolled, num_threads(), 0, stream>>>(
          N, f, data, input_calc, output_calc, loader, storer);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
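To make the launch math above concrete, here is a small host-side sketch of how the new per-thread work size (tws) feeds the grid computation. The constants are assumptions: kNumThreads stands in for PyTorch's num_threads(), and the tws values mirror the diff.

#include <cstdint>
#include <cstdio>

constexpr int64_t kNumThreads = 256;  // assumed value of num_threads()

// As in the diff: block work size = tws * threads per block,
// grid = ceil(N / bws).
int64_t grid_for(int64_t N, int64_t tws) {
  const int64_t bws = tws * kNumThreads;
  return (N + bws - 1) / bws;
}

int main() {
  const int64_t N = 1 << 20;
  printf("tws=4 : grid=%lld\n", (long long)grid_for(N, 4));   // default path (example)
  printf("tws=8 : grid=%lld\n", (long long)grid_for(N, 8));   // gfx942, io_size >= 2
  printf("tws=16: grid=%lld\n", (long long)grid_for(N, 16));  // gfx942, io_size < 2
  return 0;
}

Raising tws cuts the block count proportionally, so each block does more work per launch. Note that the unrolled fallback path keeps its own grid (grid_unrolled), since it still uses elems_per_thread<io_size>() rather than the gfx942-specific tws.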

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 24 additions & 0 deletions
@@ -187,6 +187,30 @@ template <int vec_size, typename scalar_t>
 __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
   using vec_t = aligned_vector<scalar_t, vec_size>;
   auto *from = reinterpret_cast<const vec_t *>(base_ptr);
+#if defined(USE_ROCM) && defined(__gfx942__)
+  using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int;
+  if constexpr (sizeof(vec_t) == sizeof(int)) {
+    union {
+      vec_t v;
+      int i;
+    } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const int *>(&(from[offset]))) };
+    return tmpt.v;
+  }
+  else if constexpr (sizeof(vec_t) == sizeof(long)) {
+    union {
+      vec_t v;
+      long i;
+    } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const long *>(&(from[offset]))) };
+    return tmpt.v;
+  }
+  else if constexpr (sizeof(vec_t) == sizeof(longx2)) {
+    union {
+      vec_t v;
+      longx2 i;
+    } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const longx2 *>(&(from[offset]))) };
+    return tmpt.v;
+  }
+#endif
   return from[offset];
 }
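The if constexpr/union ladder above exists because __builtin_nontemporal_load accepts scalar or Clang vector types, while vec_t is a struct; the union reinterprets the loaded bits without a strict-aliasing violation. A reduced sketch of the same pattern for the 16-byte case only (types simplified, names illustrative, not the library code):

// Load four floats (16 bytes) with a non-temporal hint.
struct __attribute__((aligned(16))) float4_like { float val[4]; };
using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;

__device__ float4_like load_nt16(const float4_like* from) {
  union {
    float4_like v;  // the type the caller wants back
    intx4 i;        // a type the builtin can load directly
  } tmp = { .i = __builtin_nontemporal_load(
                reinterpret_cast<const intx4*>(from)) };
  return tmp.v;
}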
