Skip to content

Commit b30f178

Browse files
hongxiayang and pytorchmergebot
authored and committed
Replace assert with CUDA_KERNEL_ASSERT in Reduce.cuh for consistency (#113098)
Related to Fixes #94891 **Problem:** We are trying to disable `printf` in kernels for Pytorch build on ROCm to fix the `torch.sum()` issues for certain community users by disabling `CUDA_KERNEL_ASSERT`, but found that there are still hostcall printf happening in `ReduceSumProdKernel` used by `torch.sum`. **Reason:** The reason is that there are `assert` function calls inside `Reduce.cuh`, ( defined as `__assert_fail` ) which caused `printf`. **Fix:** This pull request is to change `assert` to `CUDA_KERNEL_ASSERT` so that we can consistently disable assertion/printf in cuda/hip kernel code. Pull Request resolved: #113098 Approved by: https://github.com/ezyang
1 parent 77e8e8f commit b30f178

File tree

3 files changed

+6
-9
lines changed

3 files changed

+6
-9
lines changed

aten/src/ATen/native/cuda/PersistentSoftmax.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
#include <assert.h>
43
#include <cfloat>
54
#include <limits>
65
#include <stdint.h>

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
#include <assert.h>
43
#include <ATen/core/Array.h>
54
#include <ATen/cuda/CUDAContext.h>
65
#include <ATen/cuda/DeviceUtils.cuh>
@@ -483,7 +482,7 @@ struct ReduceOp {
483482
template <int output_vec_size>
484483
C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
485484
if (config.vectorize_input) {
486-
assert(output_vec_size == 1);
485+
CUDA_KERNEL_ASSERT(output_vec_size == 1);
487486
// reduce at the header of input_slice where memory is not aligned,
488487
// so that thread_reduce will have an aligned memory to work on.
489488
return {input_vectorized_thread_reduce_impl(data)};
@@ -720,7 +719,7 @@ struct ReduceOp {
720719
out_scalar_t* out, arg_t value,
721720
typename std::enable_if<can_acc>::type* = nullptr
722721
) const {
723-
assert(!final_output);
722+
CUDA_KERNEL_ASSERT(!final_output);
724723
return (out_scalar_t)value;
725724
}
726725

@@ -733,7 +732,7 @@ struct ReduceOp {
733732
at::detail::Array<arg_t, output_vec_size>,
734733
typename std::enable_if<!can_acc>::type* = nullptr
735734
) const {
736-
assert(false); // can't use AT_ASSERT in Cuda.
735+
CUDA_KERNEL_ASSERT(false);
737736
return arg_t {};
738737
}
739738

@@ -745,13 +744,13 @@ struct ReduceOp {
745744
out_scalar_t* out, arg_t value,
746745
typename std::enable_if<!can_acc>::type* = nullptr
747746
) const {
748-
assert(false);
747+
CUDA_KERNEL_ASSERT(false);
749748
return *out;
750749
}
751750

752751
template<class T>
753752
C10_DEVICE void set_results(const T x, const index_t base_offset) const {
754-
assert(noutputs == 1);
753+
CUDA_KERNEL_ASSERT(noutputs == 1);
755754
auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
756755
*res = x;
757756
}
@@ -773,7 +772,7 @@ struct ReduceOp {
773772

774773
template <int output_vec_size>
775774
C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
776-
assert(final_output);
775+
CUDA_KERNEL_ASSERT(final_output);
777776
#pragma unroll
778777
for (int i = 0; i < output_vec_size; i++) {
779778
set_results(ops.project(value[i]), base_offset[i]);

aten/src/ATen/native/cuda/SortingCommon.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#include <ATen/core/TensorBase.h>
33
#include <ATen/ceil_div.h>
44
#include <ATen/NumericUtils.h>
5-
#include <assert.h>
65
#include <c10/macros/Macros.h>
76
#include <stdlib.h>
87
#include <ATen/cuda/detail/IndexUtils.cuh>

0 commit comments

Comments (0)