@@ -35,6 +35,26 @@ struct AtomicFPOp<at::Half> {
   }
 };
 
+template <>
+struct AtomicFPOp<c10::complex<float>> {
+  template <typename func_t>
+  inline __device__ c10::complex<float> operator() (c10::complex<float> *address, c10::complex<float> val, const func_t& func) {
+    unsigned long long int* addr_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *addr_as_ull;
+    unsigned long long int assumed, new_val;
+
+    c10::complex<float> csum;
+    do {
+      assumed = old;
+      csum = func(*reinterpret_cast<c10::complex<float>*>(&assumed), val);  // apply the op to the value currently stored
+      new_val = *reinterpret_cast<unsigned long long int*>(&csum);
+      old = atomicCAS(addr_as_ull, assumed, new_val);
+    } while (assumed != old);
+
+    return *reinterpret_cast<c10::complex<float>*>(&old);  // value held before the successful update
+  }
+};
+
 template <>
 struct AtomicFPOp<at::BFloat16> {
   template <typename func_t>
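As a reference for the compare-and-swap pattern the new c10::complex<float> specialization relies on, here is a minimal stand-alone sketch using a plain 8-byte struct in place of c10::complex<float> so it builds without ATen headers; the names Cplx, atomicMulComplex, and mul_kernel are illustrative only and not part of this change.

#include <cuda_runtime.h>

// 8-byte, 8-byte-aligned stand-in for c10::complex<float>, so the real/imag
// pair can be updated with a single 64-bit atomicCAS.
struct alignas(8) Cplx { float re, im; };

__device__ Cplx atomicMulComplex(Cplx* address, Cplx val) {
  unsigned long long int* addr_as_ull = reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *addr_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Reinterpret the bits currently stored at the address as a complex value.
    Cplx cur = *reinterpret_cast<Cplx*>(&assumed);
    // (a+bi)(c+di) = (ac-bd) + (ad+bc)i
    Cplx prod{cur.re * val.re - cur.im * val.im,
              cur.re * val.im + cur.im * val.re};
    unsigned long long int new_val = *reinterpret_cast<unsigned long long int*>(&prod);
    // Retry until no other thread changed the value between the read and the swap.
    old = atomicCAS(addr_as_ull, assumed, new_val);
  } while (assumed != old);
  return *reinterpret_cast<Cplx*>(&old);  // value held before the successful update
}

__global__ void mul_kernel(Cplx* acc, Cplx v) { atomicMulComplex(acc, v); }

int main() {
  Cplx h{1.0f, 0.0f}, *d;
  cudaMalloc(&d, sizeof(Cplx));
  cudaMemcpy(d, &h, sizeof(Cplx), cudaMemcpyHostToDevice);
  mul_kernel<<<1, 32>>>(d, Cplx{0.0f, 1.0f});  // 32 multiplications by i; result is i^32 = 1
  cudaMemcpy(&h, d, sizeof(Cplx), cudaMemcpyDeviceToHost);
  cudaFree(d);
  return 0;
}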
@@ -348,6 +368,14 @@ GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
 GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
 GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
 
+inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val) {
+  return AtomicFPOp<c10::complex<float>>()(address, val,
+                                           [](c10::complex<float> bsum, c10::complex<float> val) {
+                                             bsum *= val;
+                                             return bsum;
+                                           });
+}
+
 inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
   return AtomicFPOp<at::Half>()(address, val,
                                 [](at::Half bsum, at::Half val) {
@@ -369,7 +397,7 @@ inline __device__ double gpuAtomicMul(double * address, double val) {
   });
 }
 
-// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
+// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
 inline __device__ float gpuAtomicMul(float * address, float val) {
   unsigned int * address_as_ull = (unsigned int *)address;
   unsigned int old = *address_as_ull;
@@ -402,6 +430,29 @@ __host__ __device__ T safe_max(T a, T b) {
   return max;
 }
 
+__inline__ __device__ c10::complex<float> complex_max(c10::complex<float> a, c10::complex<float> b) {
+  if (at::_isnan(b)) {
+    return b;
+  } else {
+    // Compute the magnitude of each complex number and return the operand whose magnitude is greater.
+    float a_magnitude = __fsqrt_rn(
+      (
+        __fmul_rn(a.real(), a.real()) +
+        __fmul_rn(a.imag(), a.imag())
+      )
+    );
+    float b_magnitude = __fsqrt_rn(
+      (
+        __fmul_rn(b.real(), b.real()) +
+        __fmul_rn(b.imag(), b.imag())
+      )
+    );
+    return a_magnitude >= b_magnitude ? a : b;
+  }
+
+}
+
+
 ATOMIC_INTEGER_IMPL(Max)
 GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
 GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
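As a plain restatement of the ordering used above (and by complex_min later in this diff): the operand with the larger, or respectively smaller, Euclidean magnitude is returned, and a NaN in the incoming value propagates. The sketch below uses an illustrative Cplx struct and ordinary float math in place of c10::complex<float> and the __fsqrt_rn/__fmul_rn intrinsics.

#include <cuda_runtime.h>
#include <cmath>

struct Cplx { float re, im; };  // stand-in for c10::complex<float>

__host__ __device__ inline float magnitude(Cplx v) {
  // Same quantity the intrinsics above compute: sqrt(re*re + im*im).
  return sqrtf(v.re * v.re + v.im * v.im);
}

__host__ __device__ inline Cplx complex_max_sketch(Cplx a, Cplx b) {
  // A NaN component in b propagates, mirroring the at::_isnan(b) check.
  if (b.re != b.re || b.im != b.im) {
    return b;
  }
  return magnitude(a) >= magnitude(b) ? a : b;
}

__host__ __device__ inline Cplx complex_min_sketch(Cplx a, Cplx b) {
  if (b.re != b.re || b.im != b.im) {
    return b;
  }
  return magnitude(a) <= magnitude(b) ? a : b;
}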
@@ -416,6 +467,13 @@ inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
   });
 }
 
+inline __device__ c10::complex<float> gpuAtomicMax(c10::complex<float> * address, c10::complex<float> val) {
+  return AtomicFPOp<c10::complex<float>>()(address, val,
+                                           [](c10::complex<float> bsum, c10::complex<float> val) {
+                                             return complex_max(bsum, val);
+                                           });
+}
+
 inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
   return AtomicFPOp<at::BFloat16>()(address, val,
                                     [](at::BFloat16 bsum, at::BFloat16 val) {
@@ -462,6 +520,27 @@ __host__ __device__ T safe_min(T a, T b) {
   return min;
 }
 
+__inline__ __device__ c10::complex<float> complex_min(c10::complex<float> a, c10::complex<float> b) {
+  if (at::_isnan(b)) {
+    return b;
+  } else {
+    // Compute the magnitude of each complex number and return the operand whose magnitude is smaller.
+    float a_magnitude = __fsqrt_rn(
+      (
+        __fmul_rn(a.real(), a.real()) +
+        __fmul_rn(a.imag(), a.imag())
+      )
+    );
+    float b_magnitude = __fsqrt_rn(
+      (
+        __fmul_rn(b.real(), b.real()) +
+        __fmul_rn(b.imag(), b.imag())
+      )
+    );
+    return a_magnitude <= b_magnitude ? a : b;
+  }
+}
+
 ATOMIC_INTEGER_IMPL(Min)
 GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
 GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
@@ -476,6 +555,13 @@ inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
   });
 }
 
+inline __device__ c10::complex<float> gpuAtomicMin(c10::complex<float> * address, c10::complex<float> val) {
+  return AtomicFPOp<c10::complex<float>>()(address, val,
+                                           [](c10::complex<float> bsum, c10::complex<float> val) {
+                                             return complex_min(bsum, val);
+                                           });
+}
+
 inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
   return AtomicFPOp<at::BFloat16>()(address, val,
                                     [](at::BFloat16 bsum, at::BFloat16 val) {
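Finally, a hedged usage sketch of the new complex<float> overloads from a kernel; the include path and the reduce_extrema kernel name are assumptions made here for illustration, and the header is simply whichever one this diff modifies.

// Sketch only: assumes the header touched by this diff is reachable as
// <ATen/cuda/Atomic.cuh> within a PyTorch build and brings c10::complex into scope.
#include <ATen/cuda/Atomic.cuh>

// Each thread folds its element into shared accumulators; the CAS loop inside
// AtomicFPOp<c10::complex<float>> serializes conflicting updates.
__global__ void reduce_extrema(c10::complex<float>* prod,
                               c10::complex<float>* max_out,
                               c10::complex<float>* min_out,
                               const c10::complex<float>* data,
                               int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    gpuAtomicMul(prod, data[i]);
    gpuAtomicMax(max_out, data[i]);
    gpuAtomicMin(min_out, data[i]);
  }
}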