
Commit 72337bd

Aidyn-A authored and pytorchmergebot committed
[ATen][CUDA] Optimize 128 bit vectorization (#148320)
Fixes #147376. As per request #145746 (review), this PR excludes sm80 and older architectures from using the vec8 kernels, due to their long compilation time and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman
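In practice the gate is host-side: before launching, the code queries the device's compute capability and caps the requested vectorization width at 4 unless the GPU is sm_90 or sm_100, the only targets for which the 128-bit (vec8) kernel body is compiled (see the diff below, which uses at::cuda::getDeviceProperties). The following is a minimal standalone sketch of that logic using the plain CUDA runtime API; the helper name pick_vec_size is illustrative only, not part of the PR:

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Sketch only: cap the vectorization width at 4 on anything that is not
// sm_90 or sm_100, mirroring the check the PR adds to launch_vectorized_kernel.
uint16_t pick_vec_size(int device, uint16_t vec_size) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);        // CUDA runtime counterpart of the ATen call
  const int cc = prop.major * 10 + prop.minor;   // e.g. 80 (A100), 90 (H100), 100 (B100)
  if (cc != 90 && cc != 100) {
    vec_size = std::min<uint16_t>(vec_size, 4);  // no vec8 path on sm_80 and older
  }
  return vec_size;
}

int main() {
  std::printf("vec_size = %u\n",
              static_cast<unsigned>(pick_vec_size(/*device=*/0, /*vec_size=*/8)));
  return 0;
}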
1 parent 3baa85c commit 72337bd

File tree

1 file changed: +69 -4 lines


aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 69 additions & 4 deletions
@@ -158,6 +158,69 @@ constexpr auto calc_io_size(){
 #endif
 }
 
+#ifndef USE_ROCM
+// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
+// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
+// used on sm_90 and sm_100 exclusively.
+template <int vec_size, typename func_t, typename array_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  if constexpr (vec_size == 8) {
+#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` data to handle, use
+             // vectorized memory access
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+  } else {
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` data to handle, use
+             // vectorized memory access
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+  }
+}
+
+#else // USE_ROCM
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@@ -182,15 +245,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-#ifdef USE_ROCM
     constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
-#else
-    constexpr auto optimal_vec_size = vec_size;
-#endif
     elementwise_kernel_helper(
         f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
   }
 }
+#endif // USE_ROCM
 
 template <
     typename func_t,
@@ -237,6 +297,11 @@ static inline void launch_vectorized_kernel(
   // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
   // that causes some numerical mismatches with uint8 on sm80 and sm90.
   // TODO: Revisit this after CUDA 12.8 update.
+  cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
+  const int computeCapability = p->major * 10 + p->minor;
+  if (computeCapability != 90 && computeCapability != 100) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
   if constexpr (sizeof(cpp_type) < 2) {
     vec_size = std::min<uint16_t>(vec_size, 4);
   }
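For context on why the new kernel is both split and guarded: every template instantiation of a __global__ function is compiled for every architecture the build targets, so wrapping the vec8 body in if constexpr (vec_size == 8) plus a __CUDA_ARCH__ check lets that instantiation collapse to a near-empty kernel on anything other than sm_90 / sm_100, which is where the compile-time and binary-size savings come from. A minimal, hypothetical sketch of the same pattern (demo_kernel, the float4 arithmetic, and the launch shapes are illustrative only, not the PyTorch code):

#include <cuda_runtime.h>

// The wide path exists only when compiling for sm_90 / sm_100; every other
// architecture gets a trivially small instantiation of demo_kernel<8>.
template <int vec_size>
__global__ void demo_kernel(float* out, const float* in, int n) {
  const int base = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size;
  if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
    // Wide path: two 128-bit (float4) loads/stores per thread.
    // Assumes in/out are 16-byte aligned (true for cudaMalloc allocations).
    if (base + vec_size <= n) {
      const float4* in4 = reinterpret_cast<const float4*>(in + base);
      float4* out4 = reinterpret_cast<float4*>(out + base);
      for (int k = 0; k < 2; ++k) {
        float4 v = in4[k];
        v.x *= 2.f; v.y *= 2.f; v.z *= 2.f; v.w *= 2.f;
        out4[k] = v;
      }
    }
#endif
  } else {
    // Narrow path (vec_size 1/2/4), compiled for every architecture.
    for (int k = 0; k < vec_size && base + k < n; ++k) {
      out[base + k] = in[base + k] * 2.0f;
    }
  }
}

int main() {
  const int n = 1 << 20;
  const int threads = 256;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(in, 0, n * sizeof(float));
  // vec8 instantiation: does real work only in sm_90 / sm_100 builds.
  demo_kernel<8><<<n / (threads * 8), threads>>>(out, in, n);
  // vec4 instantiation: runs everywhere.
  demo_kernel<4><<<n / (threads * 4), threads>>>(out, in, n);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}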
