Skip to content

Commit a1fe323

Browse files
ZelboKpytorchmergebot
authored andcommitted
Include support for the scatter gather cuda kernels to allow for complex<float>. Initial foundation
1 parent 07d3af8 commit a1fe323

File tree

5 files changed

+140
-11
lines changed

5 files changed

+140
-11
lines changed

CMakeCache.txt

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# This is the CMakeCache file.
2+
# For build in directory: /home/ksm/pytorch
3+
# It was generated by CMake: /home/ksm/anaconda3/envs/pyt_dev/bin/cmake
4+
# You can edit this file to change values found and used by cmake.
5+
# If you do not want to change any of the values, simply exit the editor.
6+
# If you do want to change a value, simply edit, save, and exit the editor.
7+
# The syntax for the file is as follows:
8+
# KEY:TYPE=VALUE
9+
# KEY is the name of a variable in the cache.
10+
# TYPE is a hint to GUIs for the type of VALUE, DO NOT EDIT TYPE!.
11+
# VALUE is the current value for the KEY.
12+
13+
########################
14+
# EXTERNAL cache entries
15+
########################
16+
17+
18+
########################
19+
# INTERNAL cache entries
20+
########################
21+
22+
//This is the directory where this CMakeCache.txt was created
23+
CMAKE_CACHEFILE_DIR:INTERNAL=/home/ksm/pytorch
24+
//Major version of cmake used to create the current loaded cache
25+
CMAKE_CACHE_MAJOR_VERSION:INTERNAL=3
26+
//Minor version of cmake used to create the current loaded cache
27+
CMAKE_CACHE_MINOR_VERSION:INTERNAL=26
28+
//Patch version of cmake used to create the current loaded cache
29+
CMAKE_CACHE_PATCH_VERSION:INTERNAL=4
30+
//Path to CMake executable.
31+
CMAKE_COMMAND:INTERNAL=/home/ksm/anaconda3/envs/pyt_dev/bin/cmake
32+
//Path to cpack program executable.
33+
CMAKE_CPACK_COMMAND:INTERNAL=/home/ksm/anaconda3/envs/pyt_dev/bin/cpack
34+
//Path to ctest program executable.
35+
CMAKE_CTEST_COMMAND:INTERNAL=/home/ksm/anaconda3/envs/pyt_dev/bin/ctest
36+
//Path to CMake installation.
37+
CMAKE_ROOT:INTERNAL=/home/ksm/anaconda3/envs/pyt_dev/share/cmake-3.26
38+

CMakeFiles/cmake.check_cache

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file is generated by cmake for dependency checking of the CMakeCache.txt file

aten/src/ATen/cuda/Atomic.cuh

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,26 @@ struct AtomicFPOp<at::Half> {
3535
}
3636
};
3737

38+
template <>
39+
struct AtomicFPOp<c10::complex<float>> {
40+
template <typename func_t>
41+
inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
42+
unsigned long long int* addr_as_ull = (unsigned long long int*)address;
43+
unsigned long long int old = *addr_as_ull;
44+
unsigned long long int assumed, new_val;
45+
46+
c10::complex<float> csum;
47+
do {
48+
assumed = old;
49+
csum = func(csum, val);
50+
new_val = *reinterpret_cast<unsigned long long*>(&csum);
51+
old = atomicCAS(addr_as_ull, assumed, new_val);
52+
} while (assumed != old);
53+
54+
return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
55+
}
56+
};
57+
3858
template <>
3959
struct AtomicFPOp<at::BFloat16> {
4060
template <typename func_t>
@@ -348,6 +368,14 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
348368
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
349369
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
350370

371+
inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
372+
return AtomicFPOp<c10::complex<float>>()(address, val,
373+
[](c10::complex<float> bsum, c10::complex<float> val) {
374+
bsum*=(val);
375+
return bsum;
376+
});
377+
}
378+
351379
inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
352380
return AtomicFPOp<at::Half>()(address, val,
353381
[](at::Half bsum, at::Half val) {
@@ -369,7 +397,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
369397
});
370398
}
371399

372-
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
400+
// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
373401
inline __device__ float gpuAtomicMul (float * address, float val) {
374402
unsigned int* address_as_ull = (unsigned int*)address;
375403
unsigned int old = *address_as_ull;
@@ -402,6 +430,29 @@ __host__ __device__ T safe_max(T a, T b) {
402430
return max;
403431
}
404432

433+
__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
434+
if(at::_isnan(b)) {
435+
return b;
436+
} else {
437+
// Compute the magnitude of the complex numbers and compare each to see which one is greater.
438+
float a_magnitude = __fsqrt_rn(
439+
(
440+
__fmul_rn(a.real(), a.real()) +
441+
__fmul_rn(a.imag(),a.imag())
442+
)
443+
);
444+
float b_magnitude = __fsqrt_rn(
445+
(
446+
__fmul_rn(b.real(), b.real()) +
447+
__fmul_rn(b.imag(),b.imag())
448+
)
449+
);
450+
return std::max<float>(a_magnitude, b_magnitude);
451+
}
452+
453+
}
454+
455+
405456
ATOMIC_INTEGER_IMPL(Max)
406457
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
407458
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
@@ -416,6 +467,13 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
416467
});
417468
}
418469

470+
inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
471+
return AtomicFPOp<c10::complex<float>>()(address, val,
472+
[](c10::complex<float> bsum, c10::complex<float> val) {
473+
return complex_max(bsum, val);
474+
});
475+
}
476+
419477
inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
420478
return AtomicFPOp<at::BFloat16>()(address, val,
421479
[](at::BFloat16 bsum, at::BFloat16 val) {
@@ -462,6 +520,27 @@ __host__ __device__ T safe_min(T a, T b) {
462520
return min;
463521
}
464522

523+
__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
524+
if(at::_isnan(b)) {
525+
return b;
526+
} else {
527+
// Compute the magnitude of the complex numbers and compare each to see which one is smaller.
528+
float a_magnitude = __fsqrt_rn(
529+
(
530+
__fmul_rn(a.real(), a.real()) +
531+
__fmul_rn(a.imag(),a.imag())
532+
)
533+
);
534+
float b_magnitude = __fsqrt_rn(
535+
(
536+
__fmul_rn(b.real(), b.real()) +
537+
__fmul_rn(b.imag(),b.imag())
538+
)
539+
);
540+
return std::min<float>(a_magnitude, b_magnitude);
541+
}
542+
}
543+
465544
ATOMIC_INTEGER_IMPL(Min)
466545
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
467546
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
@@ -476,6 +555,13 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
476555
});
477556
}
478557

558+
inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
559+
return AtomicFPOp<c10::complex<float>>()(address, val,
560+
[](c10::complex<float> bsum, c10::complex<float> val) {
561+
return complex_min(bsum, val);
562+
});
563+
}
564+
479565
inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
480566
return AtomicFPOp<at::BFloat16>()(address, val,
481567
[](at::BFloat16 bsum, at::BFloat16 val) {

aten/src/ATen/native/cuda/ScatterGatherKernel.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <ATen/core/Tensor.h>
55
#include <ATen/Dispatch.h>
66
#include <ATen/MemoryOverlap.h>
7-
7+
#include <iostream>
88
#include <ATen/native/ScatterGatherChecks.h>
99
#include <ATen/native/ReduceOpsUtils.h>
1010
#include <ATen/native/TensorIterator.h>
@@ -201,7 +201,6 @@ struct cuda_scatter_gather_base_kernel {
201201
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
202202
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
203203

204-
205204
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
206205
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
207206
iter.dtype(),
@@ -259,7 +258,6 @@ struct cuda_scatter_gather_base_kernel {
259258
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
260259
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
261260

262-
263261
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
264262
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
265263
iter.dtype(),
@@ -318,9 +316,10 @@ struct cuda_scatter_gather_base_kernel {
318316
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
319317
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
320318

321-
322-
AT_DISPATCH_ALL_TYPES_AND2(
319+
// this should have complex in it
320+
AT_DISPATCH_ALL_TYPES_AND3(
323321
at::ScalarType::Half, at::ScalarType::BFloat16,
322+
at::ScalarType::ComplexFloat,
324323
iter.dtype(),
325324
"cuda_scatter_gather_base_kernel_func", [&] {
326325
using dtype = typename std::conditional<cast_to_opaque,
@@ -450,8 +449,9 @@ struct cuda_scatter_fill_base_kernel {
450449
auto index_size = ensure_nonempty_size(self, dim);
451450
auto index_stride = ensure_nonempty_stride(self, dim);
452451

453-
AT_DISPATCH_ALL_TYPES_AND2(
452+
AT_DISPATCH_ALL_TYPES_AND3(
454453
at::ScalarType::Half, at::ScalarType::BFloat16,
454+
at::ScalarType::ComplexFloat,
455455
iter.dtype(),
456456
"cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
457457
using dtype = typename std::conditional<cast_to_opaque,

test/test_scatter_gather_ops.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,17 @@ def test_scatter_reduce_sum(self, device, dtype):
221221
include_self=include_self)
222222

223223
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
224-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
224+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
225+
include_complex=False, include_bool=False))
225226
def test_scatter_reduce_prod(self, device, dtype):
226227
for include_self in (True, False):
227228
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
228229
is_scalar=False, reduction='prod', unique_indices=False,
229230
include_self=include_self)
230231

231232
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
232-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
233+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
234+
include_complex=False, include_bool=False))
233235
def test_scatter_reduce_mean(self, device, dtype):
234236
for include_self in (True, False):
235237
for deterministic in [False, True]:
@@ -239,7 +241,8 @@ def test_scatter_reduce_mean(self, device, dtype):
239241
include_self=include_self)
240242

241243
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
242-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
244+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
245+
include_complex=False, include_bool=False))
243246
def test_scatter_reduce_amax(self, device, dtype):
244247
for include_self in (True, False):
245248
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
@@ -258,7 +261,8 @@ def test_scatter_reduce_amax(self, device, dtype):
258261

259262

260263
@dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
261-
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
264+
@dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
265+
include_complex=False, include_bool=False))
262266
def test_scatter_reduce_amin(self, device, dtype):
263267
for include_self in (True, False):
264268
self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,

0 commit comments

Comments
 (0)