4 changes: 4 additions & 0 deletions aten/src/ATen/NumericUtils.h
@@ -38,7 +38,11 @@ inline C10_HOST_DEVICE bool _isnan(T val) {

template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return ::isnan(val.real()) || ::isnan(val.imag());
#else
return std::isnan(val.real()) || std::isnan(val.imag());
#endif
}

template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
87 changes: 86 additions & 1 deletion aten/src/ATen/cuda/Atomic.cuh
@@ -35,6 +35,26 @@ struct AtomicFPOp<at::Half> {
}
};

template <>
struct AtomicFPOp<c10::complex<float>> {
template <typename func_t>
inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
unsigned long long int* addr_as_ull = (unsigned long long int*)address;
unsigned long long int old = *addr_as_ull;
unsigned long long int assumed, new_val;

c10::complex<float> csum;
do {
assumed = old;
csum = func(csum, val);
new_val = *reinterpret_cast<unsigned long long*>(&csum);
old = atomicCAS(addr_as_ull, assumed, new_val);
} while (assumed != old);

return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
Collaborator:
This isn't atomic? You need to return csum directly, otherwise the value at addr_as_ull may change underneath you.

Collaborator:
In fact this is also wrong as atomic read-modify-write ops return the old value, not the new value. So this should be bit-casting assumed.

Contributor Author (@ZelboK, May 4, 2024):
> In fact this is also wrong as atomic read-modify-write ops return the old value, not the new value. So this should be bit-casting assumed.

Sorry for the oversight. Could you help me understand? I know that atomicCAS returns the old value, but what are you referring to with that?

I understand that addr_as_ull shouldn't be returned, since another thread can change it, correct? Why should we use assumed, though, and not csum?

Collaborator:
assumed is the value before performing the update, which is what is returned by normal atomicAdd, atomicMax, etc.

See the CAS implementation for half as an example:

hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
return hsum;

Contributor Author (@ZelboK, May 6, 2024):
Oh, you're right. I got tunnel-visioned on the actual call to atomicCAS; yes, it should be assumed. I forgot that I am actually implementing an atomic operation here and that it should follow suit, lol.
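
For reference, here is a minimal sketch of what the corrected loop could look like, following the feedback above: the observed bits in assumed are reinterpreted as the current complex value before applying the update, and assumed is bit-cast again for the return value so the function reports the pre-update value, like atomicAdd and the other atomic RMW ops. The helper name atomic_complex_update is made up for illustration, assuming the same includes as Atomic.cuh; this is not the code that was merged.

// Illustrative sketch only, not the merged implementation.
template <typename func_t>
inline __device__ c10::complex<float> atomic_complex_update(
    c10::complex<float>* address, c10::complex<float> val, const func_t& func) {
  static_assert(sizeof(c10::complex<float>) == sizeof(unsigned long long int),
                "bit-casting requires matching sizes");
  unsigned long long int* addr_as_ull = reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *addr_as_ull;
  unsigned long long int assumed;

  do {
    assumed = old;
    // Reinterpret the observed bits as a complex value, then apply the update to it.
    c10::complex<float> current = *reinterpret_cast<c10::complex<float>*>(&assumed);
    c10::complex<float> updated = func(current, val);
    old = atomicCAS(addr_as_ull, assumed,
                    *reinterpret_cast<unsigned long long int*>(&updated));
  } while (assumed != old);

  // Atomic read-modify-write ops return the value *before* the update,
  // i.e. the bits captured in assumed, not the new value and not *address.
  return *reinterpret_cast<c10::complex<float>*>(&assumed);
}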

}
};

template <>
struct AtomicFPOp<at::BFloat16> {
template <typename func_t>
@@ -348,6 +368,14 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
Contributor (@mikaylagawarecki, May 3, 2024):
The build failure is totally not your fault, as it can't be seen from external CI; we only see it when internal workflows run after the PR is merged.

Looking at the failure and pattern-matching a bit, it looks like maybe we need __host__ __device__ here, as well as for complex_max on line 433.

Does this change make sense? I can import the PR and see whether this fixes the internal build tomorrow morning.

Contributor Author:
Unfortunately, since complex_min and complex_max both use CUDA intrinsics, they won't compile if you also make them __host__ functions. The use of __fsqrt_rn, for example, should lead to more performant code/better CUDA assembly. CUDA intrinsics should be taken advantage of, imo, because this is kernel code and complex numbers are heavier computations in general.

The easiest solution would be to add an overload for complex when compiling with CUDA, so that an operator*= is available as __host__ __device__.

Just adding this:

#if defined(__CUDACC__) || defined(__HIPCC__)
  template <typename U>
  C10_HOST_DEVICE constexpr complex<T>& operator*=(const complex<U>& rhs) {
    // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
    T a = real_;
    T b = imag_;
    U c = rhs.real();
    U d = rhs.imag();
    real_ = a * c - b * d;
    imag_ = a * d + b * c;
    return *this;
  }
#endif

in complex.h should fix this problem.

Also, on second look, I made an oversight in the complex_max and complex_min functions. They should be using regular comparisons and not std::max, given they are __device__ functions. So on that note, it's actually good that this PR got reverted! I will push those changes and things should build on your end.

return AtomicFPOp<c10::complex<float>>()(address, val,
[](c10::complex<float> bsum, c10::complex<float> val) {
bsum*=(val);
return bsum;
});
}

inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
return AtomicFPOp<at::Half>()(address, val,
[](at::Half bsum, at::Half val) {
@@ -369,7 +397,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
});
}

// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul (float * address, float val) {
unsigned int* address_as_ull = (unsigned int*)address;
unsigned int old = *address_as_ull;
@@ -402,6 +430,28 @@ __host__ __device__ T safe_max(T a, T b) {
return max;
}

__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
if(at::_isnan(b)) {
return b;
} else {
// Compute the magnitude of the complex numbers and compare each to see which one is greater.
float a_magnitude = __fsqrt_rn(
(
__fmul_rn(a.real(), a.real()) +
__fmul_rn(a.imag(),a.imag())
)
);
float b_magnitude = __fsqrt_rn(
(
__fmul_rn(b.real(), b.real()) +
__fmul_rn(b.imag(),b.imag())
)
);
return (a_magnitude > b_magnitude) ? a : b;
Collaborator:
Is there any precedent for this definition of complex max/min in PyTorch?

Contributor Author (@ZelboK, May 4, 2024):
I am not experienced enough with PyTorch to answer that. Aside from using magnitudes, how else would you order them? I followed convention from other ecosystems, and from my research this is how it is done across different disciplines/domains, is it not?

Collaborator:
Exactly, they cannot be ordered (in mathematical terms, complex numbers are not an ordered field).
We should error in these cases, the same as we error when we call max on a complex tensor. If people want to use these ops on complex tensors, they can do a view_as_real and perform some transformations on the output to define the order they want.
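
As an aside, a minimal sketch of the kind of workaround being suggested (libtorch C++ frontend, hypothetical helper name, not part of this PR): build a magnitude-based ordering out of existing ops instead of a kernel-level complex max.

#include <torch/torch.h>

// Pick the element of a complex tensor with the largest magnitude,
// using view_as_real plus ordinary real-valued reductions.
torch::Tensor magnitude_max(const torch::Tensor& z) {
  auto flat = z.flatten();
  auto pairs = torch::view_as_real(flat);      // shape (N, 2), real dtype
  auto mag2 = pairs.pow(2).sum(/*dim=*/1);     // squared magnitude per element
  auto idx = mag2.argmax();                    // index of the largest magnitude
  return flat.take(idx);                       // the corresponding complex value
}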

Contributor Author:
I was under the impression that some contexts use magnitude for ordering complex numbers, like spectral analysis for DSP. I also took motivation from https://www.mathworks.com/help/matlab/ref/max.html.

@Franklalalala

Could you comment on whether or not you had a use case for scattering complex numbers? What kind of work were you trying to do? Would you know if ordering of complex numbers is practically useful?

Collaborator:
It may be implemented and it may be useful, but we don't implement that in PyTorch at a kernel level. As mentioned above, all these orderings can often be simulated with the current API and a bit of imagination :)

Contributor Author:
I see; in that case I'll wait until @mikaylagawarecki has a chance to review again. Thanks for taking a look!

@Franklalalala:

@ZelboK I am working on Tensor_network, which requires a series of matrix multiplications. In the case of complex elements, torch scatter cannot be used on GPU. As far as I am concerned right now, we do not use sorting here, just element-wise multiplication.
By the way, we have worked out a way around it: we transform the complex numbers through the Euler (polar) form, which turns the multiplication into addition of angles and multiplication of magnitudes.
The excellent work you guys are doing is beyond my knowledge base, so I cannot give any more advice. But thanks!

Contributor Author:

Thanks a lot for responding, I was genuinely curious. This helps give me perspective :)

}
}


ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
@@ -416,6 +466,13 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
});
}

inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
return AtomicFPOp<c10::complex<float>>()(address, val,
[](c10::complex<float> bsum, c10::complex<float> val) {
return complex_max(bsum, val);
});
}

inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
@@ -462,6 +519,27 @@ __host__ __device__ T safe_min(T a, T b) {
return min;
}

__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
if(at::_isnan(b)) {
return b;
} else {
// Compute the magnitude of the complex numbers and compare each to see which one is smaller.
float a_magnitude = __fsqrt_rn(
(
__fmul_rn(a.real(), a.real()) +
__fmul_rn(a.imag(),a.imag())
)
);
float b_magnitude = __fsqrt_rn(
(
__fmul_rn(b.real(), b.real()) +
__fmul_rn(b.imag(),b.imag())
)
);
return (a_magnitude < b_magnitude) ? a : b;
}
}

ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
@@ -476,6 +554,13 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
});
}

inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
return AtomicFPOp<c10::complex<float>>()(address, val,
[](c10::complex<float> bsum, c10::complex<float> val) {
return complex_min(bsum, val);
});
}

inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
10 changes: 4 additions & 6 deletions aten/src/ATen/native/cuda/ScatterGatherKernel.cu
@@ -4,7 +4,6 @@
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>

#include <ATen/native/ScatterGatherChecks.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorIterator.h>
@@ -201,7 +200,6 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
@@ -259,7 +257,6 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
@@ -318,9 +315,9 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


AT_DISPATCH_ALL_TYPES_AND2(
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16,
at::ScalarType::ComplexFloat,
iter.dtype(),
"cuda_scatter_gather_base_kernel_func", [&] {
using dtype = typename std::conditional<cast_to_opaque,
@@ -450,8 +447,9 @@ struct cuda_scatter_fill_base_kernel {
auto index_size = ensure_nonempty_size(self, dim);
auto index_stride = ensure_nonempty_stride(self, dim);

AT_DISPATCH_ALL_TYPES_AND2(
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16,
at::ScalarType::ComplexFloat,
iter.dtype(),
"cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
using dtype = typename std::conditional<cast_to_opaque,
12 changes: 8 additions & 4 deletions test/test_scatter_gather_ops.py
@@ -221,15 +221,17 @@ def test_scatter_reduce_sum(self, device, dtype):
include_self=include_self)

@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
def test_scatter_reduce_prod(self, device, dtype):
for include_self in (True, False):
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
is_scalar=False, reduction='prod', unique_indices=False,
include_self=include_self)

@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
def test_scatter_reduce_mean(self, device, dtype):
for include_self in (True, False):
for deterministic in [False, True]:
@@ -239,7 +241,8 @@ def test_scatter_reduce_mean(self, device, dtype):
include_self=include_self)

@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
def test_scatter_reduce_amax(self, device, dtype):
for include_self in (True, False):
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
@@ -258,7 +261,8 @@ def test_scatter_reduce_amax(self, device, dtype):


@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
include_complex=False, include_bool=False))
def test_scatter_reduce_amin(self, device, dtype):
for include_self in (True, False):
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
6 changes: 3 additions & 3 deletions test/test_torch.py
@@ -57,8 +57,8 @@
_create_scaling_case, _create_scaling_models_optimizers)
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_dtype import (
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
all_types_and, floating_types, floating_and_complex_types, integral_types_and,
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, all_types_and, floating_types,
floating_and_complex_types, integral_types_and,
get_all_qint_dtypes,
)
from torch.testing._internal.two_tensor import TwoTensor
@@ -3837,7 +3837,7 @@ def test_scatter_reduce_non_unique_index(self, device, dtype):
self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")

@onlyCUDA
@dtypes(*complex_types())
@dtypes(torch.cdouble)
def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
height = 2
width = 2
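
For context, a hypothetical usage sketch (libtorch C++ frontend, requires a CUDA build and device; not part of this diff): with the kernel changes above, scatter_reduce_ with reduce="prod" should dispatch for complex64 (ComplexFloat) tensors on CUDA, while complex128 still takes the unsupported-dtype path that the updated test above exercises.

#include <iostream>
#include <torch/torch.h>

int main() {
  auto copts = torch::dtype(torch::kComplexFloat).device(torch::kCUDA);
  auto self = torch::ones({2, 2}, copts);
  auto src = torch::randn({2, 2}, copts);
  // All indices point at row 0, so the non-unique indices exercise the atomic path.
  auto index = torch::zeros({2, 2}, torch::dtype(torch::kLong).device(torch::kCUDA));
  self.scatter_reduce_(/*dim=*/0, index, src, "prod", /*include_self=*/true);
  std::cout << self << std::endl;
  return 0;
}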