6 changes: 4 additions & 2 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -78,7 +78,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
}
}

#ifdef USE_ROCM
#if defined(USE_ROCM)
template <int io_sizes>
constexpr auto elems_per_thread(){
if constexpr (io_sizes == 1) {
@@ -219,7 +219,7 @@ static inline void launch_vectorized_kernel(
constexpr auto io_size = calc_io_size<func_t>();
int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
auto stream = at::cuda::getCurrentCUDAStream();
#ifdef USE_ROCM
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
int vec_size = memory::can_vectorize_up_to<func_t>(data);
#else
using cpp_type = typename function_traits<func_t>::result_type;
@@ -241,11 +241,13 @@ static inline void launch_vectorized_kernel(
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
#endif
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
case 8:
vectorized_elementwise_kernel<8, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
#endif
case 4:
vectorized_elementwise_kernel<4, func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data);
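Note: a minimal, self-contained sketch of the dispatch pattern this hunk extends — the launcher picks the widest vector size the data allows and switches to the matching kernel instantiation, with the 8-wide case compiled only for ROCm or CUDA 12.8+. The kernel body and helper names below are illustrative stand-ins, not the actual ATen implementation.

```cuda
// Hedged sketch (not the ATen code): a width-specialized kernel plus the
// runtime switch on vec_size that this hunk extends. The 8-wide case is only
// compiled for ROCm or CUDA 12.8+, mirroring the guards in the diff.
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>

template <int vec_size>
__global__ void vectorized_elementwise_kernel(int64_t N, float* out, const float* in) {
  // Each thread handles vec_size contiguous elements; the real ATen kernels use
  // aligned vector loads/stores instead of this scalar loop.
  int64_t base = (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * vec_size;
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    if (base + i < N) out[base + i] = in[base + i] + 1.0f;
  }
}

void launch_vectorized(int64_t N, float* out, const float* in, int vec_size,
                       cudaStream_t stream) {
  const int threads = 128;  // stand-in for num_threads()
  auto grid_for = [&](int v) {
    return static_cast<unsigned>((N + int64_t(threads) * v - 1) / (int64_t(threads) * v));
  };
  switch (vec_size) {
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
    case 8:  // new widest case, available only where the guard above holds
      vectorized_elementwise_kernel<8><<<grid_for(8), threads, 0, stream>>>(N, out, in);
      break;
#endif
    case 4:
      vectorized_elementwise_kernel<4><<<grid_for(4), threads, 0, stream>>>(N, out, in);
      break;
    case 2:
      vectorized_elementwise_kernel<2><<<grid_for(2), threads, 0, stream>>>(N, out, in);
      break;
    default:
      vectorized_elementwise_kernel<1><<<grid_for(1), threads, 0, stream>>>(N, out, in);
      break;
  }
}
```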
5 changes: 4 additions & 1 deletion aten/src/ATen/native/cuda/Dropout.cu
@@ -216,12 +216,15 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
#ifdef USE_ROCM
// make sure we don't break assumption that we can't have > 16 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
#else
#elif (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
vec_size = std::min<int>(optimal_vec_size, vec_size);

// make sure we don't break assumption that we can't have > 8 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
#else
// make sure we don't break assumption that we can't have > 4 elements / thread
TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
#endif
}

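Note: a small sketch of the cap this hunk introduces for CUDA 12.8+ — the widest load is 16 bytes, so the per-thread vector width is bounded by 16 / sizeof(scalar_t) before being clamped to whatever the alignment check already allowed. The helper below is hypothetical and uses int16_t as a stand-in for half.

```cuda
// Hedged sketch of the CUDA 12.8+ cap added in this hunk: the widest load is
// 16 bytes, so the per-thread vector width is bounded by 16 / sizeof(scalar_t)
// and then clamped to what the alignment check already allowed.
#include <algorithm>
#include <cstdint>
#include <cstdio>

template <typename scalar_t>
int capped_vec_size(int alignment_vec_size /* e.g. from can_vectorize_up_to */) {
  const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
  return std::min<int>(optimal_vec_size, alignment_vec_size);
}

int main() {
  // With a fully aligned pointer (alignment alone would allow 8):
  std::printf("4-byte type (float)         -> %d\n", capped_vec_size<float>(8));   // 4
  std::printf("2-byte type (half stand-in) -> %d\n", capped_vec_size<int16_t>(8)); // 8
  std::printf("8-byte type (double)        -> %d\n", capped_vec_size<double>(8));  // 2
  return 0;
}
```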
4 changes: 3 additions & 1 deletion aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -486,7 +486,9 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
Collaborator: should USE_ROCM here also be inverted if the CUDA_VERSION condition is >= 12080?

Contributor Author: No, I don't think so. Before #145746, vec8_alignment was only available under USE_ROCM; after it, it was enabled unconditionally, and I want it enabled for either ROCm or CUDA newer than 12.6.

constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
#endif
#ifdef USE_ROCM
constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
constexpr int type_size = sizeof(scalar_t);
@@ -495,7 +497,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
} else if (type_size <= 2 && (address % vec8_alignment == 0)) {
return 8;
} else
#else
#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080
Contributor: Shouldn't there be some logic to handle the case when CUDA_VERSION < 12080?

Contributor: Hi @ZainRizvi this is basically redoing #145746, only if CUDA >= 12.8.

Contributor (@atalman, Apr 4, 2025): Hence this code should not be applied by default but only for CUDA 12.8+

if (address % vec8_alignment == 0) {
  return 8;
} else

if (address % vec8_alignment == 0) {
return 8;
} else
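Note: a simplified, hedged sketch of the alignment logic around this hunk — can_vectorize_up_to returns the widest vector size whose aligned_vector alignment divides the pointer's address, and the vec8 branch now exists for ROCm or CUDA 12.8+. aligned_vector below is a stand-in for the ATen helper, and the ROCm-only vec16 and type-size checks are omitted.

```cuda
// Hedged sketch (simplified): return the widest vector size whose
// aligned_vector alignment divides the pointer's address. The vec8 branch
// mirrors the new ROCm / CUDA 12.8+ path in the diff.
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <typename scalar_t>
int can_vectorize_up_to(const char* pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
  constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
  constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
  if (address % vec8_alignment == 0) {
    return 8;
  } else if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}

int main() {
  alignas(64) static char buffer[64];
  std::printf("%d\n", can_vectorize_up_to<float>(buffer));      // 8: 32-byte aligned
  std::printf("%d\n", can_vectorize_up_to<float>(buffer + 4));  // 1: only 4-byte aligned
  return 0;
}
```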
5 changes: 4 additions & 1 deletion aten/src/ATen/native/cuda/thread_constants.h
@@ -18,8 +18,11 @@ constexpr int thread_work_size() { return 4; }
constexpr uint32_t num_threads() {
return C10_WARP_SIZE * 4;
}

#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
constexpr int thread_work_size() { return 4; }
#else
constexpr int thread_work_size() { return 8; }
#endif
#endif

constexpr int block_work_size() { return thread_work_size() * num_threads(); }
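Note: a quick sanity check of what this change implies for per-block work, assuming C10_WARP_SIZE == 32 on NVIDIA; the helpers below are illustrative, not the real constants.

```cuda
// Hedged arithmetic check, assuming C10_WARP_SIZE == 32 on NVIDIA:
// num_threads() = 32 * 4 = 128, so block_work_size() grows from
// 4 * 128 = 512 elements (CUDA < 12.8) to 8 * 128 = 1024 (CUDA >= 12.8).
#include <cstdio>

constexpr int kWarpSize = 32;               // assumption for NVIDIA hardware
constexpr int kNumThreads = kWarpSize * 4;  // 128
constexpr int block_work(int thread_work) { return thread_work * kNumThreads; }

int main() {
  std::printf("CUDA <  12.8: %d elements per block\n", block_work(4));  // 512
  std::printf("CUDA >= 12.8: %d elements per block\n", block_work(8));  // 1024
  return 0;
}
```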
19 changes: 12 additions & 7 deletions aten/src/ATen/test/cuda_vectorized_test.cu
@@ -46,12 +46,17 @@ TEST(TestLoops, HasSameArgTypes) {

TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
char *ptr = reinterpret_cast<char *>(buffer1);
#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
constexpr auto vectorize_limit = 4;
#else
constexpr auto vectorize_limit = 8;
#endif

ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), vectorize_limit);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), vectorize_limit);
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), vectorize_limit);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), vectorize_limit);
ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), vectorize_limit);

ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +70,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);

ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), vectorize_limit);
ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), vectorize_limit);
ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
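Note: the updated expectations follow from alignment arithmetic — at a given pointer offset, the usable width per type is the largest one whose byte footprint divides that offset's alignment, now capped at 8 instead of 4 on CUDA 12.8+. A hedged sketch of that arithmetic (expected_width is an illustrative helper, not test code):

```cuda
// Hedged sketch of the alignment arithmetic behind the updated expectations:
// the usable width is the largest one whose byte footprint
// (sizeof(type) * width) divides the pointer's alignment, capped at the limit.
#include <cstdio>

int expected_width(int type_size, int alignment, int limit) {
  int width = limit;
  while (width > 1 && (alignment % (type_size * width)) != 0) {
    width /= 2;
  }
  return width;
}

int main() {
  // A pointer that is exactly 8-byte aligned (e.g. buffer + 8 over an over-aligned buffer):
  std::printf("bool    -> %d\n", expected_width(1, 8, 8));  // 8
  std::printf("int16_t -> %d\n", expected_width(2, 8, 8));  // 4
  std::printf("int     -> %d\n", expected_width(4, 8, 8));  // 2
  std::printf("int64_t -> %d\n", expected_width(8, 8, 8));  // 1
  return 0;
}
```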