Commit a0e2f62

Revert "Include support for the scatter gather cuda kernels to allow for comp… (#124809)"
This reverts commit 9e24c26. Reverted #124809 on behalf of https://github.com/kit1980 due to breaking internal builds (see the comment on #124809).
Parent: b1b0399 · Commit: a0e2f62

File tree: 5 files changed (+14 −106 lines)


aten/src/ATen/NumericUtils.h

Lines changed: 0 additions & 4 deletions

@@ -38,11 +38,7 @@ inline C10_HOST_DEVICE bool _isnan(T val) {
 
 template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
 inline C10_HOST_DEVICE bool _isnan(T val) {
-#if defined(__CUDACC__) || defined(__HIPCC__)
-  return ::isnan(val.real()) || ::isnan(val.imag());
-#else
   return std::isnan(val.real()) || std::isnan(val.imag());
-#endif
 }
 
 template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
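
Context for the hunk above: the deleted branch routed complex NaN checks through the device-overloaded ::isnan when compiled by nvcc or hipcc; after the revert, host and device paths both use std::isnan on each component. A minimal, self-contained sketch of the same check (illustrative only; it uses std::complex rather than c10::complex, and is_complex_nan is a hypothetical name):

#include <cmath>
#include <complex>

// A complex value is NaN when either its real or imaginary part is NaN.
template <typename T>
bool is_complex_nan(const std::complex<T>& v) {
  return std::isnan(v.real()) || std::isnan(v.imag());
}

int main() {
  std::complex<float> z(1.0f, std::nanf(""));
  return is_complex_nan(z) ? 0 : 1;  // z is NaN, so this returns 0
}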

aten/src/ATen/cuda/Atomic.cuh

Lines changed: 1 addition & 87 deletions

@@ -35,26 +35,6 @@ struct AtomicFPOp<at::Half> {
   }
 };
 
-template <>
-struct AtomicFPOp<c10::complex<float>> {
-  template <typename func_t>
-  inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
-    unsigned long long int* addr_as_ull = (unsigned long long int*)address;
-    unsigned long long int old = *addr_as_ull;
-    unsigned long long int assumed, new_val;
-
-    c10::complex<float> csum;
-    do {
-      assumed = old;
-      csum = func(csum, val);
-      new_val = *reinterpret_cast<unsigned long long*>(&csum);
-      old = atomicCAS(addr_as_ull, assumed, new_val);
-    } while (assumed != old);
-
-    return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
-  }
-};
-
 template <>
 struct AtomicFPOp<at::BFloat16> {
   template <typename func_t>
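
The specialization deleted above follows CUDA's standard compare-and-swap retry loop for types with no native atomic instruction. A hedged, generic sketch of that pattern for any 64-bit payload (illustrative; atomic_rmw_64 is a hypothetical name, and it reinterprets the observed bits back into T before applying the functor):

// Generic read-modify-write via atomicCAS for a 64-bit payload.
// Each iteration snapshots the current bits, applies the update
// functor, and tries to publish the result; it retries whenever
// another thread changed the location in between.
template <typename T, typename func_t>
__device__ T atomic_rmw_64(T* address, T val, const func_t& func) {
  static_assert(sizeof(T) == sizeof(unsigned long long int),
                "this sketch assumes a 64-bit payload");
  unsigned long long int* addr_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *addr_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    T current = *reinterpret_cast<T*>(&assumed);  // value observed this round
    T updated = func(current, val);               // apply the reduction
    old = atomicCAS(addr_as_ull, assumed,
                    *reinterpret_cast<unsigned long long int*>(&updated));
  } while (assumed != old);  // lost the race: retry with the newer value
  return *reinterpret_cast<T*>(&old);  // value held before the final update
}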
@@ -368,14 +348,6 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
 GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
 GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
 
-inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
-  return AtomicFPOp<c10::complex<float>>()(address, val,
-                                           [](c10::complex<float> bsum, c10::complex<float> val) {
-                                             bsum*=(val);
-                                             return bsum;
-                                           });
-}
-
 inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
   return AtomicFPOp<at::Half>()(address, val,
     [](at::Half bsum, at::Half val) {
@@ -397,7 +369,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
     });
 }
 
-// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
+// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
 inline __device__ float gpuAtomicMul (float * address, float val) {
   unsigned int* address_as_ull = (unsigned int*)address;
   unsigned int old = *address_as_ull;
@@ -430,29 +402,6 @@ __host__ __device__ T safe_max(T a, T b) {
   return max;
 }
 
-__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
-  if(at::_isnan(b)) {
-    return b;
-  } else {
-    // Compute the magnitude of the complex numbers and compare each to see which one is greater.
-    float a_magnitude = __fsqrt_rn(
-      (
-        __fmul_rn(a.real(), a.real()) +
-        __fmul_rn(a.imag(),a.imag())
-      )
-    );
-    float b_magnitude = __fsqrt_rn(
-      (
-        __fmul_rn(b.real(), b.real()) +
-        __fmul_rn(b.imag(),b.imag())
-      )
-    );
-    return std::max<float>(a_magnitude, b_magnitude);
-  }
-
-}
-
-
 ATOMIC_INTEGER_IMPL(Max)
 GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
 GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
@@ -467,13 +416,6 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
     });
 }
 
-inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
-  return AtomicFPOp<c10::complex<float>>()(address, val,
-                                           [](c10::complex<float> bsum, c10::complex<float> val) {
-                                             return complex_max(bsum, val);
-                                           });
-}
-
 inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
   return AtomicFPOp<at::BFloat16>()(address, val,
     [](at::BFloat16 bsum, at::BFloat16 val) {
@@ -520,27 +462,6 @@ __host__ __device__ T safe_min(T a, T b) {
   return min;
 }
 
-__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
-  if(at::_isnan(b)) {
-    return b;
-  } else {
-    // Compute the magnitude of the complex numbers and compare each to see which one is smaller.
-    float a_magnitude = __fsqrt_rn(
-      (
-        __fmul_rn(a.real(), a.real()) +
-        __fmul_rn(a.imag(),a.imag())
-      )
-    );
-    float b_magnitude = __fsqrt_rn(
-      (
-        __fmul_rn(b.real(), b.real()) +
-        __fmul_rn(b.imag(),b.imag())
-      )
-    );
-    return std::min<float>(a_magnitude, b_magnitude);
-  }
-}
-
 ATOMIC_INTEGER_IMPL(Min)
 GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
 GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
@@ -555,13 +476,6 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
     });
 }
 
-inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
-  return AtomicFPOp<c10::complex<float>>()(address, val,
-                                           [](c10::complex<float> bsum, c10::complex<float> val) {
-                                             return complex_min(bsum, val);
-                                           });
-}
-
 inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
   return AtomicFPOp<at::BFloat16>()(address, val,
     [](at::BFloat16 bsum, at::BFloat16 val) {
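
The deleted complex_max/complex_min helpers order complex operands by magnitude, propagating NaN, using the round-to-nearest intrinsics __fmul_rn and __fsqrt_rn. Note that they return the larger or smaller magnitude as a float (implicitly converted to a complex with zero imaginary part). A hedged sketch that instead returns the winning operand, assuming ATen's c10/util/complex.h and ATen/NumericUtils.h are available (complex_max_by_magnitude is a hypothetical name):

#include <ATen/NumericUtils.h>
#include <c10/util/complex.h>

// |z| via round-to-nearest CUDA intrinsics, as in the deleted helpers.
__device__ inline float magnitude(c10::complex<float> z) {
  return __fsqrt_rn(__fmul_rn(z.real(), z.real()) +
                    __fmul_rn(z.imag(), z.imag()));
}

// Magnitude order with NaN propagation; returns the winning operand
// rather than its magnitude.
__device__ inline c10::complex<float> complex_max_by_magnitude(
    c10::complex<float> a, c10::complex<float> b) {
  if (at::_isnan(b)) {
    return b;  // NaN wins, matching the deleted helpers
  }
  return magnitude(a) >= magnitude(b) ? a : b;
}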

aten/src/ATen/native/cuda/ScatterGatherKernel.cu

Lines changed: 6 additions & 4 deletions

@@ -4,6 +4,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
 #include <ATen/MemoryOverlap.h>
+
 #include <ATen/native/ScatterGatherChecks.h>
 #include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/TensorIterator.h>
@@ -200,6 +201,7 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
+
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
       iter.dtype(),
@@ -257,6 +259,7 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
+
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
       iter.dtype(),
@@ -315,9 +318,9 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
-    AT_DISPATCH_ALL_TYPES_AND3(
+
+    AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
-      at::ScalarType::ComplexFloat,
       iter.dtype(),
       "cuda_scatter_gather_base_kernel_func", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
@@ -447,9 +450,8 @@ struct cuda_scatter_fill_base_kernel {
     auto index_size = ensure_nonempty_size(self, dim);
     auto index_stride = ensure_nonempty_stride(self, dim);
 
-    AT_DISPATCH_ALL_TYPES_AND3(
+    AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
-      at::ScalarType::ComplexFloat,
      iter.dtype(),
       "cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
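
For readers unfamiliar with the dispatch macros changed above: AT_DISPATCH_ALL_TYPES_AND2/_AND3 switch on the runtime dtype and run a lambda with scalar_t bound to the matching C++ type, where the _ANDn suffix lists extra ScalarTypes beyond the default all-types set. Dropping at::ScalarType::ComplexFloat from that list is what removes complex coverage from the multiply/prod paths. A hedged usage sketch (a fragment, assuming a TensorIterator iter and ATen's gpu_kernel from ATen/native/cuda/Loops.cuh; the kernel name string is illustrative):

AT_DISPATCH_ALL_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::BFloat16,  // extra types beyond the default set
    iter.dtype(), "example_mul_kernel", [&] {
      // Inside the lambda, scalar_t is the dispatched C++ type
      // (float, int64_t, at::Half, ...).
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
        return a * b;
      });
    });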

test/test_scatter_gather_ops.py

Lines changed: 4 additions & 8 deletions

@@ -221,17 +221,15 @@ def test_scatter_reduce_sum(self, device, dtype):
                                 include_self=include_self)
 
     @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
-    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
-                                  include_complex=False, include_bool=False))
+    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
     def test_scatter_reduce_prod(self, device, dtype):
         for include_self in (True, False):
             self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                     is_scalar=False, reduction='prod', unique_indices=False,
                                     include_self=include_self)
 
     @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
-    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
-                                  include_complex=False, include_bool=False))
+    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
     def test_scatter_reduce_mean(self, device, dtype):
         for include_self in (True, False):
             for deterministic in [False, True]:
@@ -241,8 +239,7 @@ def test_scatter_reduce_mean(self, device, dtype):
                                         include_self=include_self)
 
     @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
-    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
-                                  include_complex=False, include_bool=False))
+    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
     def test_scatter_reduce_amax(self, device, dtype):
         for include_self in (True, False):
             self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
@@ -261,8 +258,7 @@ def test_scatter_reduce_amax(self, device, dtype):
 
 
     @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
-    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex32=True,
-                                  include_complex=False, include_bool=False))
+    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
     def test_scatter_reduce_amin(self, device, dtype):
         for include_self in (True, False):
             self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,

test/test_torch.py

Lines changed: 3 additions & 3 deletions

@@ -57,8 +57,8 @@
     _create_scaling_case, _create_scaling_models_optimizers)
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
 from torch.testing._internal.common_dtype import (
-    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, all_types_and, floating_types,
-    floating_and_complex_types, integral_types_and,
+    floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types,
+    all_types_and, floating_types, floating_and_complex_types, integral_types_and,
     get_all_qint_dtypes,
 )
 from torch.testing._internal.two_tensor import TwoTensor
@@ -3837,7 +3837,7 @@ def test_scatter_reduce_non_unique_index(self, device, dtype):
             self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}")
 
     @onlyCUDA
-    @dtypes(torch.cdouble)
+    @dtypes(*complex_types())
     def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype):
         height = 2
         width = 2
