
Commit c02aba6

Revert "[ATen][CUDA] Implement 128 bit vectorization v2 (#145746)"
This reverts commit e84bf88.
1 parent 861d2cc commit c02aba6

File tree

8 files changed (+21, -77 lines)
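For context, the reverted change (#145746) used 128-bit (16-byte) memory transactions on CUDA, so each load covered 16 / sizeof(T) elements; this revert caps the CUDA path at 4 elements again while ROCm keeps the wider vec8/vec16 paths. The sketch below is illustrative only and is not part of the patch:

```cpp
// Illustrative only -- not part of this commit. A 128-bit vectorized access
// moves 16 bytes at once, so the element count per load is 16 / sizeof(T).
#include <cstdio>

template <typename T>
constexpr int elems_per_128bit_load() {
  return 16 / static_cast<int>(sizeof(T));
}

int main() {
  // 4 floats, 8 two-byte elements, 16 one-byte elements per 16-byte load.
  std::printf("float: %d, 2-byte: %d, 1-byte: %d\n",
              elems_per_128bit_load<float>(),
              elems_per_128bit_load<short>(),
              elems_per_128bit_load<unsigned char>());
  return 0;
}
```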

aten/src/ATen/native/cuda/CUDAJitLoops.cuh

Lines changed: 4 additions & 16 deletions

```diff
@@ -49,8 +49,8 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
-  at::cuda::jit::NvrtcFunction vec8;
 #ifdef USE_ROCM
+  at::cuda::jit::NvrtcFunction vec8;
   at::cuda::jit::NvrtcFunction vec16;
 #endif
 
@@ -131,30 +131,18 @@ void launch_jitted_vectorized_kernel(
   int vec_size = at::cuda::jit::can_vectorize_up_to(
       desc, c10::ArrayRef<char*>(data.data(), data.size()));
 
-#ifndef USE_ROCM
-  const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
-  const int optimal_vec_size = 16 / static_cast<int>(input_size);
-  vec_size = std::min<int>(optimal_vec_size, vec_size);
-  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
-  // that causes some numerical mismatches with uint8 on sm80 and sm90.
-  // TODO: Revisit this after CUDA 12.8 update.
-  if (input_size < 2) {
-    vec_size = std::min<int>(vec_size, 4);
-  }
-#endif
-
   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
 
 #ifdef USE_ROCM
   if (vec_size == 16) {
     fn_ptr = &fn_cache.vec16;
+  } else if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
   } else
 #endif
-  if (vec_size == 8) {
-    fn_ptr = &fn_cache.vec8;
-  } else if (vec_size == 4) {
+  if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
     fn_ptr = &fn_cache.vec2;
```
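After the revert, the JIT path takes the width reported by `can_vectorize_up_to` and dispatches to the cached NVRTC kernel for that width, with the vec8/vec16 entries existing only on ROCm. A minimal standalone sketch of the restored dispatch shape (`FakeKernel`, `VecKernelCache`, and `pick_kernel` are simplified stand-ins, not PyTorch symbols):

```cpp
// Minimal sketch of the post-revert dispatch shape; types are simplified
// stand-ins for at::cuda::jit::NvrtcFunction and JittedVecKernelCache.
struct FakeKernel {};

struct VecKernelCache {
  FakeKernel vec1, vec2, vec4;
#ifdef USE_ROCM
  FakeKernel vec8, vec16;  // wider paths remain ROCm-only after the revert
#endif
};

FakeKernel* pick_kernel(VecKernelCache& cache, int vec_size) {
  FakeKernel* fn = nullptr;
#ifdef USE_ROCM
  if (vec_size == 16) {
    fn = &cache.vec16;
  } else if (vec_size == 8) {
    fn = &cache.vec8;
  } else
#endif
  if (vec_size == 4) {
    fn = &cache.vec4;
  } else if (vec_size == 2) {
    fn = &cache.vec2;
  } else {
    fn = &cache.vec1;
  }
  return fn;
}
```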

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 2 additions & 25 deletions

```diff
@@ -78,7 +78,6 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
 
-#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -89,16 +88,6 @@ constexpr auto elems_per_thread(){
     return 4;
   }
 }
-#else
-template <int io_sizes>
-constexpr auto elems_per_thread(){
-  if constexpr (io_sizes == 1) {
-    return 16;
-  } else {
-    return 8;
-  }
-}
-#endif
 
 template <int io_sizes>
 constexpr auto io_block_work_size() {
@@ -219,33 +208,21 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
-#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
-#else
-  using cpp_type = typename function_traits<func_t>::result_type;
-  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
-  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
-  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
-  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
-  // that causes some numerical mismatches with uint8 on sm80 and sm90.
-  // TODO: Revisit this after CUDA 12.8 update.
-  if constexpr (sizeof(cpp_type) < 2) {
-    vec_size = std::min<uint16_t>(vec_size, 4);
-  }
-#endif
+
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:
       vectorized_elementwise_kernel<16, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
-#endif
    case 8:
      vectorized_elementwise_kernel<8, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
+#endif
    case 4:
      vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
```
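The launch math itself is unchanged by the revert: the grid still covers N elements in chunks of `io_block_work_size()`, and only the vec_size selection and the gating of the wide cases move. A rough sketch with assumed constants (the real values come from `num_threads()` and `elems_per_thread<io_sizes>()`):

```cpp
// Rough sketch with assumed constants; the real values come from
// num_threads() and elems_per_thread<io_sizes>() in CUDALoops.cuh.
#include <cstdint>
#include <cstdio>

constexpr int kNumThreads = 128;     // assumption: C10_WARP_SIZE (32) * 4 on the CUDA branch
constexpr int kElemsPerThread = 4;   // CUDA path caps at 4 after the revert
constexpr int kBlockWorkSize = kNumThreads * kElemsPerThread;

int main() {
  int64_t N = 1000000;
  // ceil-divide N by the per-block work size to get the grid dimension.
  int64_t grid = (N + kBlockWorkSize - 1) / kBlockWorkSize;
  std::printf("grid = %lld blocks of %d threads\n",
              static_cast<long long>(grid), kNumThreads);
  return 0;
}
```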

aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 1 addition & 4 deletions

```diff
@@ -217,11 +217,8 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
   // make sure we don't break assumption that we can't have > 16 elements / thread
   TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
 #else
-  const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
-  vec_size = std::min<int>(optimal_vec_size, vec_size);
-
   // make sure we don't break assumption that we can't have > 4 elements / thread
-  TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
+  TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
 #endif
 }
 
```
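The dropout change only tightens the invariant back to its pre-128-bit values. A hedged host-side sketch of that invariant (not the kernel itself; the function and flag are invented for illustration):

```cpp
// Sketch of the invariant restored by this revert: the ROCm path may pick
// a vector size up to 16, while the CUDA path again tops out at 4.
#include <cassert>

void check_vec_size(int vec_size, bool is_rocm) {
  if (is_rocm) {
    assert(vec_size <= 16 && "Value of VEC must be in [2, 4, 8, 16]");
  } else {
    assert(vec_size <= 4 && "Value of VEC must be in [2, 4]");
  }
}

int main() {
  check_vec_size(4, /*is_rocm=*/false);  // fine on both backends
  check_vec_size(16, /*is_rocm=*/true);  // only valid on ROCm
  return 0;
}
```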

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 1 addition & 5 deletions

```diff
@@ -486,19 +486,15 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
 #ifdef USE_ROCM
+  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
   if (type_size == 1 && (address % vec16_alignment == 0)) {
     return 16;
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
-#else
-  if (address % vec8_alignment == 0) {
-    return 8;
-  } else
 #endif
   if (address % vec4_alignment == 0) {
     return 4;
```
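With this change the CUDA branch of `can_vectorize_up_to` goes back to probing only vec4 and vec2 alignment. A host-only approximation (`alignof(T)` stands in for the `aligned_vector` alignment; for scalar element types the two coincide):

```cpp
// Host-only approximation of the CUDA branch after the revert: probe vec4
// alignment, then vec2, otherwise fall back to scalar access.
#include <cstdint>
#include <cstdio>

template <typename T>
int can_vectorize_up_to_sketch(const char* pointer) {
  auto address = reinterpret_cast<uintptr_t>(pointer);
  if (address % (4 * alignof(T)) == 0) return 4;
  if (address % (2 * alignof(T)) == 0) return 2;
  return 1;
}

int main() {
  alignas(16) float buf[16] = {};
  const char* p = reinterpret_cast<const char*>(buf);
  std::printf("%d %d %d\n",
              can_vectorize_up_to_sketch<float>(p),       // 4
              can_vectorize_up_to_sketch<float>(p + 4),   // 1
              can_vectorize_up_to_sketch<float>(p + 8));  // 2
  return 0;
}
```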

aten/src/ATen/native/cuda/jit_utils.cpp

Lines changed: 3 additions & 8 deletions

```diff
@@ -932,6 +932,7 @@ void initializeCudaContext() {
   }
 }
 
+#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
@@ -951,6 +952,7 @@ int calc_io_size(
 
   return 0;
 }
+#endif
 
 int calc_thread_work_size(
     const int nInputs,
@@ -969,14 +971,7 @@
   }
   return io_size;
 #else
-  auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type);
-  TORCH_INTERNAL_ASSERT(io_size > 0);
-  if (io_size == 1) {
-    return 16;
-  } else {
-    return 8;
-  }
-  return io_size;
+  return JIT_THREAD_WORK_SIZE;
 #endif
 }
 
```
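With `calc_io_size` now compiled only for ROCm, the non-ROCm branch of `calc_thread_work_size` reduces to the `JIT_THREAD_WORK_SIZE` constant. A minimal sketch of the restored control flow (the function name, bool parameter, and precomputed ROCm value are invented for illustration):

```cpp
// Sketch only: the real function takes input/output counts and scalar types;
// here a flag and a precomputed ROCm work size stand in for that logic.
constexpr int kJitThreadWorkSize = 4;  // mirrors JIT_THREAD_WORK_SIZE after the revert

int calc_thread_work_size_sketch(bool use_rocm, int rocm_io_based_size) {
  if (use_rocm) {
    // ROCm keeps its io_size-driven choice (computed elsewhere).
    return rocm_io_based_size;
  }
  return kJitThreadWorkSize;  // CUDA: fixed 4 elements of work per thread
}
```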

aten/src/ATen/native/cuda/jit_utils.h

Lines changed: 2 additions & 8 deletions

```diff
@@ -60,10 +60,6 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
   if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
     return 8;
   }
-#else
-  if (ip % (8 * default_alignment) == 0) {
-    return 8;
-  }
 #endif
   if (ip % (4 * default_alignment) == 0) {
     return 4;
@@ -92,17 +88,15 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*
 }
 
 //FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
-#ifdef USE_ROCM
 #define JIT_THREAD_WORK_SIZE 4
-#else
-#define JIT_THREAD_WORK_SIZE 8
-#endif
 
+#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
     const c10::ScalarType& inputs_type,
     const c10::ScalarType& result_type);
+#endif
 
 int calc_thread_work_size(
     const int nInputs,
```
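The untyped probe in jit_utils.h now matches the typed one in MemoryAccess.cuh: the 8-wide alignment test survives only behind `USE_ROCM` (and only for 1- and 2-byte alignments). Restated as a self-contained sketch, assuming the surrounding function ends as shown:

```cpp
// Self-contained restatement of the post-revert probe (the real function
// lives in jit_utils.h); compiled without USE_ROCM it tops out at 4.
#include <cstddef>
#include <cstdint>

inline int can_vectorize_up_to_sketch(size_t default_alignment, void* pointer) {
  auto ip = reinterpret_cast<uintptr_t>(pointer);
#ifdef USE_ROCM
  if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
    return 8;  // 8-wide path kept only for small types on ROCm
  }
#endif
  if (ip % (4 * default_alignment) == 0) {
    return 4;
  }
  if (ip % (2 * default_alignment) == 0) {
    return 2;
  }
  return 1;
}
```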

aten/src/ATen/native/cuda/thread_constants.h

Lines changed: 1 addition & 4 deletions

```diff
@@ -12,14 +12,11 @@
 constexpr int num_threads() {
   return 256;
 }
-
-constexpr int thread_work_size() { return 4; }
 #else
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
-
-constexpr int thread_work_size() { return 8; }
 #endif
 
+constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
```
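`thread_work_size()` is now a single definition of 4 shared by both branches, so `block_work_size()` returns to its pre-#145746 values. Worked numbers, assuming the first branch is the ROCm one and a warp size of 32 on the CUDA branch:

```cpp
// Worked numbers only; C10_WARP_SIZE is assumed to be 32 on the CUDA branch.
#include <cstdio>

constexpr int cuda_num_threads = 32 * 4;  // C10_WARP_SIZE * 4
constexpr int rocm_num_threads = 256;
constexpr int thread_work = 4;            // unified thread_work_size()

int main() {
  std::printf("CUDA block_work_size = %d\n", cuda_num_threads * thread_work);  // 512
  std::printf("ROCm block_work_size = %d\n", rocm_num_threads * thread_work);  // 1024
  return 0;
}
```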

aten/src/ATen/test/cuda_vectorized_test.cu

Lines changed: 7 additions & 7 deletions

```diff
@@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) {
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);
 
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
```
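The updated expectations follow directly from the alignment arithmetic: the offset from the (assumed 16-byte-aligned) buffer must be a multiple of `vec_width * sizeof(T)`. A small check of the `ptr + 8` cases:

```cpp
// Why ptr + 8 yields 4/4/4/2/1 on the CUDA build: an offset of 8 bytes must
// be a multiple of vec_width * sizeof(T) for the width to be usable.
#include <cstdio>

constexpr int expected(int offset, int type_size) {
  if (offset % (4 * type_size) == 0) return 4;
  if (offset % (2 * type_size) == 0) return 2;
  return 1;
}

int main() {
  std::printf("bool/int8: %d, int16: %d, int: %d, int64: %d\n",
              expected(8, 1),   // 4
              expected(8, 2),   // 4
              expected(8, 4),   // 2
              expected(8, 8));  // 1
  return 0;
}
```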
