Skip to content

Commit 9e24c26

Browse files
ZelboKpytorchmergebot
authored andcommitted
Include support for the scatter gather cuda kernels to allow for comp… (#124809)
Fixes #121965 This PR hopes to add support complex numbers in the scatter/gather related kernels. For brevity, I will only include `complex<float>` for now as `complex<double>`, for example, will be more complicated. C++ unit tests are currently passing alongside tests in `test_scatter_gather_ops.py`. Python test suites also seem to be passing. Please keep the following in mind: 1) I think this is my first time using Pytorch. 2) This is my first contribution to Pytorch. Environment: 3080 & WSL 2. `nvcc` is at 12.4. Pull Request resolved: #124809 Approved by: https://github.com/mikaylagawarecki
1 parent f1f142c commit 9e24c26

File tree

5 files changed

+106
-14
lines changed

5 files changed

+106
-14
lines changed

aten/src/ATen/NumericUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ inline C10_HOST_DEVICE bool _isnan(T val) {
3838

3939
template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
4040
inline C10_HOST_DEVICE bool _isnan(T val) {
41+
#if defined(__CUDACC__) || defined(__HIPCC__)
42+
return ::isnan(val.real()) || ::isnan(val.imag());
43+
#else
4144
return std::isnan(val.real()) || std::isnan(val.imag());
45+
#endif
4246
}
4347

4448
template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>

aten/src/ATen/cuda/Atomic.cuh

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,26 @@ struct AtomicFPOp<at::Half> {
3535
}
3636
};
3737

38+
template <>
39+
struct AtomicFPOp<c10::complex<float>> {
40+
template <typename func_t>
41+
inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
42+
unsigned long long int* addr_as_ull = (unsigned long long int*)address;
43+
unsigned long long int old = *addr_as_ull;
44+
unsigned long long int assumed, new_val;
45+
46+
c10::complex<float> csum;
47+
do {
48+
assumed = old;
49+
csum = func(csum, val);
50+
new_val = *reinterpret_cast<unsigned long long*>(&csum);
51+
old = atomicCAS(addr_as_ull, assumed, new_val);
52+
} while (assumed != old);
53+
54+
return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
55+
}
56+
};
57+
3858
template <>
3959
struct AtomicFPOp<at::BFloat16> {
4060
template <typename func_t>
@@ -348,6 +368,14 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
348368
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
349369
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
350370

371+
inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
372+
return AtomicFPOp<c10::complex<float>>()(address, val,
373+
[](c10::complex<float> bsum, c10::complex<float> val) {
374+
bsum*=(val);
375+
return bsum;
376+
});
377+
}
378+
351379
inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
352380
return AtomicFPOp<at::Half>()(address, val,
353381
[](at::Half bsum, at::Half val) {
@@ -369,7 +397,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
369397
});
370398
}
371399

372-
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
400+
// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
373401
inline __device__ float gpuAtomicMul (float * address, float val) {
374402
unsigned int* address_as_ull = (unsigned int*)address;
375403
unsigned int old = *address_as_ull;
@@ -402,6 +430,29 @@ __host__ __device__ T safe_max(T a, T b) {
402430
return max;
403431
}
404432

433+
__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
434+
if(at::_isnan(b)) {
435+
return b;
436+
} else {
437+
// Compute the magnitude of the complex numbers and compare each to see which one is greater.
438+
float a_magnitude = __fsqrt_rn(
439+
(
440+
__fmul_rn(a.real(), a.real()) +
441+
__fmul_rn(a.imag(),a.imag())
442+
)
443+
);
444+
float b_magnitude = __fsqrt_rn(
445+
(
446+
__fmul_rn(b.real(), b.real()) +
447+
__fmul_rn(b.imag(),b.imag())
448+
)
449+
);
450+
return std::max<float>(a_magnitude, b_magnitude);
451+
}
452+
453+
}
454+
455+
405456
ATOMIC_INTEGER_IMPL(Max)
406457
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
407458
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
@@ -416,6 +467,13 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
416467
});
417468
}
418469

470+
inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
471+
return AtomicFPOp<c10::complex<float>>()(address, val,
472+
[](c10::complex<float> bsum, c10::complex<float> val) {
473+
return complex_max(bsum, val);
474+
});
475+
}
476+
419477
inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
420478
return AtomicFPOp<at::BFloat16>()(address, val,
421479
[](at::BFloat16 bsum, at::BFloat16 val) {
@@ -462,6 +520,27 @@ __host__ __device__ T safe_min(T a, T b) {
462520
return min;
463521
}
464522

523+
__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
524+
if(at::_isnan(b)) {
525+
return b;
526+
} else {
527+
// Compute the magnitude of the complex numbers and compare each to see which one is smaller.
528+
float a_magnitude = __fsqrt_rn(
529+
(
530+
__fmul_rn(a.real(), a.real()) +
531+
__fmul_rn(a.imag(),a.imag())
532+
)
533+
);
534+
float b_magnitude = __fsqrt_rn(
535+
(
536+
__fmul_rn(b.real(), b.real()) +
537+
__fmul_rn(b.imag(),b.imag())
538+
)
539+
);
540+
return std::min<float>(a_magnitude, b_magnitude);
541+
}
542+
}
543+
465544
ATOMIC_INTEGER_IMPL(Min)
466545
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
467546
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
@@ -476,6 +555,13 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
476555
});
477556
}
478557

558+
inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
559+
return AtomicFPOp<c10::complex<float>>()(address, val,
560+
[](c10::complex<float> bsum, c10::complex<float> val) {
561+
return complex_min(bsum, val);
562+
});
563+
}
564+
479565
inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
480566
return AtomicFPOp<at::BFloat16>()(address, val,
481567
[](at::BFloat16 bsum, at::BFloat16 val) {

aten/src/ATen/native/cuda/ScatterGatherKernel.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <ATen/core/Tensor.h>
55
#include <ATen/Dispatch.h>
66
#include <ATen/MemoryOverlap.h>
7-
87
#include <ATen/native/ScatterGatherChecks.h>
98
#include <ATen/native/ReduceOpsUtils.h>
109
#include <ATen/native/TensorIterator.h>
@@ -201,7 +200,6 @@ struct cuda_scatter_gather_base_kernel {
201200
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
202201
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
203202

204-
205203
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
206204
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
207205
iter.dtype(),
@@ -259,7 +257,6 @@ struct cuda_scatter_gather_base_kernel {
259257
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
260258
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
261259

262-
263260
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
264261
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
265262
iter.dtype(),
@@ -318,9 +315,9 @@ struct cuda_scatter_gather_base_kernel {
318315
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
319316
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
320317

321-
322-
AT_DISPATCH_ALL_TYPES_AND2(
318+
AT_DISPATCH_ALL_TYPES_AND3(
323319
at::ScalarType::Half, at::ScalarType::BFloat16,
320+
at::ScalarType::ComplexFloat,
324321
iter.dtype(),
325322
"cuda_scatter_gather_base_kernel_func", [&] {
326323
using dtype = typename std::conditional<cast_to_opaque,
@@ -450,8 +447,9 @@ struct cuda_scatter_fill_base_kernel {
450447
auto index_size = ensure_nonempty_size(self, dim);
451448
auto index_stride = ensure_nonempty_stride(self, dim);
452449

453-
AT_DISPATCH_ALL_TYPES_AND2(
450+
AT_DISPATCH_ALL_TYPES_AND3(
454451
at::ScalarType::Half, at::ScalarType::BFloat16,
452+
at::ScalarType::ComplexFloat,
455453
iter.dtype(),
456454
"cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
457455
using dtype = typename std::conditional<cast_to_opaque,

test/test_scatter_gather_ops.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,17 @@ def test_scatter_reduce_sum(self, device, dtype):
221221
include_self=include_self)
222222

223223
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
224-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
224+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
225+
include_complex=False, include_bool=False))
225226
def test_scatter_reduce_prod(self, device, dtype):
226227
for include_self in (True, False):
227228
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
228229
is_scalar=False, reduction='prod', unique_indices=False,
229230
include_self=include_self)
230231

231232
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
232-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
233+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
234+
include_complex=False, include_bool=False))
233235
def test_scatter_reduce_mean(self, device, dtype):
234236
for include_self in (True, False):
235237
for deterministic in [False, True]:
@@ -239,7 +241,8 @@ def test_scatter_reduce_mean(self, device, dtype):
239241
include_self=include_self)
240242

241243
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
242-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
244+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
245+
include_complex=False, include_bool=False))
243246
def test_scatter_reduce_amax(self, device, dtype):
244247
for include_self in (True, False):
245248
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
@@ -258,7 +261,8 @@ def test_scatter_reduce_amax(self, device, dtype):
258261

259262

260263
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
261-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
264+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
265+
include_complex=False, include_bool=False))
262266
def test_scatter_reduce_amin(self, device, dtype):
263267
for include_self in (True, False):
264268
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,

test/test_torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@
5757
_create_scaling_case, _create_scaling_models_optimizers)
5858
from torch.testing._internal.common_mkldnn import bf32_on_and_off
5959
from torch.testing._internal.common_dtype import (
60-
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
61-
all_types_and, floating_types, floating_and_complex_types, integral_types_and,
60+
floating_types_and, get_all_math_dtypes, all_types_and_complex_and, all_types_and, floating_types,
61+
floating_and_complex_types, integral_types_and,
6262
get_all_qint_dtypes,
6363
)
6464
from torch.testing._internal.two_tensor import TwoTensor
@@ -3837,7 +3837,7 @@ def test_scatter_reduce_non_unique_index(self, device, dtype):
38373837
self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")
38383838

38393839
@onlyCUDA
3840-
@dtypes(*complex_types())
3840+
@dtypes(torch.cdouble)
38413841
def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
38423842
height = 2
38433843
width = 2

0 commit comments

Comments
 (0)