@@ -133,6 +133,69 @@ constexpr auto calc_io_size(){
 #endif
 }
 
+#ifndef USE_ROCM
+// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
+// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is used
+// exclusively on sm_90 and sm_100.
+template <int vec_size, typename func_t, typename array_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  if constexpr (vec_size == 8) {
+#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` data to handle, use
+             // vectorized memory access
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
+  } else {
+    using traits = function_traits<func_t>;
+    constexpr auto io_size = calc_io_size<func_t>();
+    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+
+    if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
+                                                     // just do a naive unrolled loop
+      auto input_calc = TrivialOffsetCalculator<traits::arity>();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      auto policy = memory::policies::unroll<
+          array_t,
+          decltype(input_calc),
+          decltype(output_calc),
+          memory::LoadWithoutCast,
+          memory::StoreWithoutCast,
+          elems_per_thread<io_size>()>(
+          data, remaining, input_calc, output_calc, loader, storer);
+      elementwise_kernel_helper(f, policy);
+    } else { // if this block has a full `block_work_size` data to handle, use
+             // vectorized memory access
+      elementwise_kernel_helper(
+          f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+    }
+  }
+}
+
+#else // USE_ROCM
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@@ -157,15 +220,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-#ifdef USE_ROCM
     constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
-#else
-    constexpr auto optimal_vec_size = vec_size;
-#endif
     elementwise_kernel_helper(
         f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
   }
 }
+#endif // USE_ROCM
 
 template <
     typename func_t,
@@ -212,6 +272,11 @@ static inline void launch_vectorized_kernel(
   // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
   // that causes some numerical mismatches with uint8 on sm80 and sm90.
   // TODO: Revisit this after CUDA 12.8 update.
+  cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
+  const int computeCapability = p->major * 10 + p->minor;
+  if (computeCapability != 90 && computeCapability != 100) { // the vec8 path above is compiled only for sm_90/sm_100
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
   if constexpr (sizeof(cpp_type) < 2) {
     vec_size = std::min<uint16_t>(vec_size, 4);
   }
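
The clamp above only adjusts vec_size; which template instantiation actually gets launched is decided by the existing vec_size dispatch further down in launch_vectorized_kernel, which this diff does not show. Below is a minimal, hypothetical CUDA sketch of that overall pattern, with an invented toy kernel and launcher (not PyTorch's real code): query the device's compute capability, fall back from vec8 to vec4 off sm_90/sm_100, then switch over vec_size to launch the matching instantiation.

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

template <int vec_size>
__global__ void toy_vectorized_kernel(int N, float* data) {
  // Each thread scales vec_size contiguous elements; a toy stand-in for the
  // real vectorized load/store policy.
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size;
  #pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    if (base + i < N) {
      data[base + i] *= 2.0f;
    }
  }
}

void launch_toy_vectorized(int N, float* data, uint16_t vec_size, cudaStream_t stream) {
  // Query the compute capability of the current device.
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  const int cc = prop.major * 10 + prop.minor;
  // Mirror the clamp in the diff: the vec8 path only exists on sm_90/sm_100.
  if (cc != 90 && cc != 100) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  const int threads = 128;
  const int blocks = (N + threads * vec_size - 1) / (threads * vec_size);
  // Dispatch to the matching template instantiation.
  switch (vec_size) {
    case 8:
      toy_vectorized_kernel<8><<<blocks, threads, 0, stream>>>(N, data);
      break;
    case 4:
      toy_vectorized_kernel<4><<<blocks, threads, 0, stream>>>(N, data);
      break;
    case 2:
      toy_vectorized_kernel<2><<<blocks, threads, 0, stream>>>(N, data);
      break;
    default:
      toy_vectorized_kernel<1><<<blocks, threads, 0, stream>>>(N, data);
      break;
  }
}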