@@ -158,6 +158,69 @@ constexpr auto calc_io_size(){
158158#endif
159159}
160160
161+ #ifndef USE_ROCM
162+ // To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
163+ // into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
164+ // used on sm_90 and sm_100 exclusively.
// Vectorized elementwise kernel (non-ROCm build).
//
// To keep the binary size of libtorch_cuda.so down, the vec_size == 8
// instantiation is only given a body when compiling for sm_90 / sm_100 —
// the only architectures the launcher dispatches vec8 to; on every other
// architecture (and in the host compilation pass, where __CUDA_ARCH__ is
// undefined) it compiles to an empty kernel.
//
// Work distribution: each block processes io_block_work_size<io_size>()
// elements. The final, partial block falls back to a naive unrolled loop
// with per-element bounds checking; all full blocks use vectorized
// loads/stores.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  // The vec8 and vec2/vec4 code paths are identical; the only difference is
  // which architectures they are compiled for. Gate a single shared body on
  // a constexpr flag instead of duplicating it in both branches.
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
  constexpr bool compile_body = true; // sm_90 / sm_100: every vec_size
#else
  constexpr bool compile_body = (vec_size != 8); // elsewhere: vec8 is empty
#endif
  if constexpr (compile_body) {
    using traits = function_traits<func_t>;
    constexpr auto io_size = calc_io_size<func_t>();
    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;

    if (remaining < io_block_work_size<io_size>()) {
      // This block handles the remainder: just do a naive unrolled loop
      // with bounds checking on every element.
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      auto policy = memory::policies::unroll<
          array_t,
          decltype(input_calc),
          decltype(output_calc),
          memory::LoadWithoutCast,
          memory::StoreWithoutCast,
          elems_per_thread<io_size>()>(
          data, remaining, input_calc, output_calc, loader, storer);
      elementwise_kernel_helper(f, policy);
    } else {
      // This block has a full `block_work_size` of data: use vectorized
      // memory access.
      elementwise_kernel_helper(
          f,
          memory::policies::vectorized<
              vec_size,
              array_t,
              elems_per_thread<io_size>()>(data));
    }
  }
}
222+
223+ #else // USE_ROCM
161224template <int vec_size, typename func_t , typename array_t >
162225C10_LAUNCH_BOUNDS_1 (num_threads())
163226__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@@ -182,15 +245,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
182245 elementwise_kernel_helper (f, policy);
183246 } else { // if this block has a full `block_work_size` data to handle, use
184247 // vectorized memory access
185- #ifdef USE_ROCM
186248 constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
187- #else
188- constexpr auto optimal_vec_size = vec_size;
189- #endif
190249 elementwise_kernel_helper (
191250 f, memory::policies::vectorized<optimal_vec_size, array_t , elems_per_thread<io_size>()>(data));
192251 }
193252}
253+ #endif // USE_ROCM
194254
195255template <
196256 typename func_t ,
@@ -237,6 +297,11 @@ static inline void launch_vectorized_kernel(
237297 // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
238298 // that causes some numerical mismatches with uint8 on sm80 and sm90.
239299 // TODO: Revisit this after CUDA 12.8 update.
300+ cudaDeviceProp* p = at::cuda::getDeviceProperties (stream.device ().index ());
301+ const int computeCapability = p->major * 10 + p->minor ;
302+ if (computeCapability != 90 && computeCapability != 100 ) {
303+ vec_size = std::min<uint16_t >(vec_size, 4 );
304+ }
240305 if constexpr (sizeof (cpp_type) < 2 ) {
241306 vec_size = std::min<uint16_t >(vec_size, 4 );
242307 }
0 commit comments