aten/src/ATen/native/cuda/LogcumsumexpKernel.cu (92 additions, 10 deletions)
@@ -11,24 +11,106 @@

namespace at::native {

// custom min and max to be used in logcumsumexp for complex arguments
template <typename scalar_t, bool min>
__host__ __device__ c10::complex<scalar_t> _logcumsumexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
  scalar_t xr = std::real(x);
  scalar_t yr = std::real(y);
  if (::isnan(yr) || (::isnan(std::imag(y)))) {
    return y;
  } else if (::isnan(xr) || (::isnan(std::imag(x)))) {
    return x;
  } else if (min) { // min
    return (xr < yr) ? x : y;
  } else { // max
    return (xr >= yr) ? x : y;
  }
}

Review thread on _logcumsumexp_minmax:

Collaborator: Actually you can revert the templating arg too; it's a bit difficult to set this up in a constexpr if statement that stays clean with all the non-constexpr conditions as well.

Collaborator: Also, all the else statements are unnecessary, since they all have return statements in them.

Contributor (author): What difference does it make if we remove the else statements?

Collaborator: @mfkasim1 It just removes extra indentation. That's why it's a nit; I don't really care either way.
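For intuition, here is a minimal Python sketch of the ordering rule discussed above (hypothetical helper name, not part of the PR): complex values are compared by real part only, and a NaN in either component of an argument makes that argument win, so NaNs propagate through the scan.

import math

def logcumsumexp_minmax(x: complex, y: complex, take_min: bool) -> complex:
    # NaN in y (or x) propagates, matching the CUDA helper above
    if math.isnan(y.real) or math.isnan(y.imag):
        return y
    if math.isnan(x.real) or math.isnan(x.imag):
        return x
    # otherwise order by the real part only
    if take_min:
        return x if x.real < y.real else y
    return x if x.real >= y.real else y

assert logcumsumexp_minmax(1 + 2j, 3 + 4j, take_min=True) == 1 + 2j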

template <typename scalar_t>
__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) {
  // Reference: https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
  // Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM
  auto isnan_x = at::_isnan(x);
  auto isnan_y = at::_isnan(y);
  scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y));
  scalar_t max = isnan_y ? y : (isnan_x ? x : std::max(x, y));
  if (min != max || ::isfinite(min)) {
    // nan will be propagated here
    return ::log1p(std::exp(min - max)) + max;
  } else {
    // special case to correctly handle infinite cases
    return x;
  }
}
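This helper is the standard numerically stable log-add-exp: log(exp(x) + exp(y)) = max + log1p(exp(min - max)), which never exponentiates a large positive number. A plain-Python sketch of the same logic, including the min == max == +/-inf special case (illustrative only, not the kernel code):

import math

def log_add_exp(x: float, y: float) -> float:
    isnan_x, isnan_y = math.isnan(x), math.isnan(y)
    lo = y if isnan_y else (x if isnan_x else min(x, y))
    hi = y if isnan_y else (x if isnan_x else max(x, y))
    if lo != hi or math.isfinite(lo):
        # exp(lo - hi) <= 1, so no overflow; NaN propagates through here
        return math.log1p(math.exp(lo - hi)) + hi
    # lo == hi == +/-inf: lo - hi would be nan, so return directly
    return x

print(log_add_exp(1000.0, 1000.0))                # ~1000.6931; the naive form overflows
print(log_add_exp(float("-inf"), float("-inf")))  # -inf, not nan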

template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
  // complex exponential function, but implemented manually to get fast compilation time
  // this function only handles the case where x is finite (not inf nor nan)
  auto xreal = std::real(x);
  auto ximag = std::imag(x);
  auto exp_x_abs = std::exp(xreal);
  auto exp_x_real = exp_x_abs * std::cos(ximag);
  auto exp_x_imag = exp_x_abs * std::sin(ximag);
  return {exp_x_real, exp_x_imag};
}
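This is just Euler's formula, exp(a + bi) = exp(a) * (cos(b) + i*sin(b)), spelled out with real-valued intrinsics (the code comment cites compilation time as the motivation). A quick illustrative check against a library implementation:

import cmath, math

def fast_build_exp(x: complex) -> complex:
    # Euler's formula with real-valued math calls only
    r = math.exp(x.real)
    return complex(r * math.cos(x.imag), r * math.sin(x.imag))

z = 0.5 + 1.25j
assert cmath.isclose(fast_build_exp(z), cmath.exp(z))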

template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
  // complex exponential function, but implemented manually to get fast compilation time
  // this function only handles the case where the real part of x is infinite
  auto ximag = std::imag(x);
  auto exp_x_abs = std::numeric_limits<scalar_t>::infinity();
  auto sin = std::sin(ximag);
  auto cos = std::cos(ximag);
  // special case if the angle is exactly a multiple of pi/2
  auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos;
  auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin;
  return {exp_x_real, exp_x_imag};
}
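The zero guards matter because IEEE-754 defines inf * 0.0 as NaN: when the magnitude is infinite and sin or cos of the angle evaluates to exactly 0.0 (e.g. a zero angle), the naive product would poison the result. A small Python illustration:

import math

inf = math.inf
print(inf * math.sin(0.0))  # nan: inf * 0.0 under IEEE-754
print(inf * math.cos(0.0))  # inf: cos(0.0) == 1.0, no problem
# the kernel's guard returns 0.0 for an exactly-zero sin/cos component
# instead of multiplying it by inf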

template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
  c10::complex<scalar_t> min = _logcumsumexp_minmax<scalar_t, /*min=*/true>(x, y);
  c10::complex<scalar_t> max = _logcumsumexp_minmax<scalar_t, /*min=*/false>(x, y);
  scalar_t min_real = std::real(min);
  scalar_t max_real = std::real(max);

  if (::isnan(min_real) || ::isnan(std::imag(min))) {
    // handling the "infectious" NaNs
    return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
  } else if ((!::isfinite(min_real)) && (min_real == max_real)) {
    if (min_real < 0) {
      // handle the -inf case; the imaginary part here does not really matter, because
      // exp(value) will be around 0.0 and the angle (i.e. the imaginary part) cannot be
      // determined. It also does not matter if we later take the exp of this value.
      return min;
    } else {
      // handle the +inf case: we don't need the special log1p precision for small values here,
      // and this avoids producing nan when real(max) == real(min) == +inf
      auto exp_min = _fast_build_exp_inf(min);
      auto exp_max = _fast_build_exp_inf(max);
      return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log
    }
  } else {
    auto minmax = min - max;
    auto exp_minmax = _fast_build_exp(minmax);
    return ::log1p(exp_minmax) + max;
  }
}

Review comment on this function:

Collaborator: Nit, but a lot of the elses also aren't needed here either, since it's all just dealing with early returns.
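For finite inputs the function computes log(exp(x) + exp(y)) via the usual factoring: pull out exp(max), leaving log1p(exp(min - max)) + max. A plain-Python reference check of that identity for complex values (illustrative only; both forms land on the same log branch for these inputs):

import cmath

x, y = 1.0 + 0.3j, 2.5 - 1.1j
lo, hi = (x, y) if x.real < y.real else (y, x)
stable = cmath.log(1 + cmath.exp(lo - hi)) + hi   # factored form used by the kernel
direct = cmath.log(cmath.exp(x) + cmath.exp(y))   # naive form; overflows for large real parts
assert cmath.isclose(stable, direct)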

void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      self.scalar_type(), "logcumsumexp_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        scalar_t init = -std::numeric_limits<scalar_t>::infinity();
        auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x_, const scalar_t y_) -> scalar_t {
          const opmath_t x{x_}, y{y_};
-         auto min = at::_isnan(y) ? y : std::min<opmath_t>(x, y); //std::min returns first arg if one of the args is nan
-         auto max = at::_isnan(y) ? y : std::max<opmath_t>(x, y); //std::max returns first arg if one of the args is nan
-         if (min != max || ::isfinite(min)) {
-           // nan will be propagated here
-           return ::log1p(std::exp(min - max)) + max;
-         } else {
-           // special case to correctly handle infinite inputs
-           return x;
-         }
+         return _log_add_exp_helper(x, y);
        };
        scan_dim<scalar_t>(self, result, dim, init, log_add_exp);
      });
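With this dispatch change, complex dtypes on CUDA now route to the complex _log_add_exp_helper above. A hedged usage sketch (assumes a CUDA build with this PR applied):

import torch

x = torch.randn(5, dtype=torch.complex64, device="cuda")
out = torch.logcumsumexp(x, dim=0)
# exp(out) should reproduce the cumulative sum of exp(x)
torch.testing.assert_close(out.exp(), x.exp().cumsum(dim=0))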
test/test_reductions.py (0 additions, 1 deletion)
@@ -504,7 +504,6 @@ def test_logsumexp(self, device):
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)

-   @onlyCPU
    @skipIfNoSciPy
    @dtypes(torch.complex64, torch.complex128)
    def test_logcumsumexp_complex(self, device, dtype):
torch/testing/_internal/common_methods_invocations.py (2 additions, 2 deletions)
@@ -16290,9 +16290,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
    ),
    OpInfo('logcumsumexp',
           dtypes=floating_and_complex_types_and(torch.bfloat16),
-          dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+          dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
           backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
-          backward_dtypesIfCUDA=floating_types_and(torch.bfloat16),
+          backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
           skips=(
               # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'),