 #include <cuda_bf16.h>
 #endif
 
+// ROCm 6.3 is planned to have these functions, but until then here they are.
+#if defined(USE_ROCM) && ROCM_VERSION >= 60201
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+
+__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
+#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
+    __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
+  typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
+  static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
+  union {
+    __hip_bfloat162_raw bf162_raw;
+    vec_short2 vs2;
+  } u{static_cast<__hip_bfloat162_raw>(value)};
+  u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
+  return static_cast<__hip_bfloat162>(u.bf162_raw);
+#else
+  static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
+  union u_hold {
+    __hip_bfloat162_raw h2r;
+    unsigned int u32;
+  };
+  u_hold old_val, new_val;
+  old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+  do {
+    new_val.h2r = __hadd2(old_val.h2r, value);
+  } while (!__hip_atomic_compare_exchange_strong(
+      (unsigned int*)address, &old_val.u32, new_val.u32,
+      __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
+  return old_val.h2r;
+#endif
+}
+
+__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
+#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
+    __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
+  // The api expects an ext_vector_type of half
+  typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
+  static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
+  union {
+    __half2_raw h2r;
+    vec_fp162 fp16;
+  } u{static_cast<__half2_raw>(value)};
+  u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
+  return static_cast<__half2>(u.h2r);
+#else
+  static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
+  union u_hold {
+    __half2_raw h2r;
+    unsigned int u32;
+  };
+  u_hold old_val, new_val;
+  old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+  do {
+    new_val.h2r = __hadd2(old_val.h2r, value);
+  } while (!__hip_atomic_compare_exchange_strong(
+      (unsigned int*)address, &old_val.u32, new_val.u32,
+      __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
+  return old_val.h2r;
+#endif
+}
+#define ATOMICADD preview_unsafeAtomicAdd
+#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
+#else
+#define ATOMICADD atomicAdd
+#define NATIVE_ZERO_BF16 __int2bfloat16_rz(0)
+#endif
+
 namespace at::native {
 
 __device__ __forceinline__ size_t
@@ -47,7 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
     const index_t numel,
     scalar_t value) {
 #if ( \
-    (defined(USE_ROCM)) || \
+    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
   gpuAtomicAddNoReturn(
       reinterpret_cast<at::Half*>(tensor) + index,
@@ -61,17 +129,22 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
     __half2 value2;
     value2.x = static_cast<__half>(value);
     value2.y = __int2half_rz(0);
-    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
+    ATOMICADD(reinterpret_cast<__half2*>(target_addr), value2);
 
   } else if (!low_byte && index > 0) {
     __half2 value2;
     value2.x = __int2half_rz(0);
     value2.y = static_cast<__half>(value);
-    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
+    ATOMICADD(reinterpret_cast<__half2*>(target_addr - 1), value2);
 
   } else {
+#ifdef USE_ROCM
+    gpuAtomicAddNoReturn(
+        reinterpret_cast<at::Half*>(tensor) + index, static_cast<at::Half>(value));
+#else
     atomicAdd(
         reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
+#endif
   }
 #endif
 }
@@ -87,7 +160,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
     const index_t numel,
     scalar_t value) {
 #if ( \
-    (defined(USE_ROCM)) || \
+    (defined(USE_ROCM) && ROCM_VERSION < 60201) || \
     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
   gpuAtomicAddNoReturn(
       reinterpret_cast<at::BFloat16*>(tensor) + index,
@@ -100,18 +173,23 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
   if (low_byte && index < (numel - 1)) {
     __nv_bfloat162 value2;
     value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
-    value2.y = __int2bfloat16_rz(0);
-    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
+    value2.y = NATIVE_ZERO_BF16;
+    ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
 
   } else if (!low_byte && index > 0) {
     __nv_bfloat162 value2;
-    value2.x = __int2bfloat16_rz(0);
+    value2.x = NATIVE_ZERO_BF16;
     value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
-    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
+    ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
 
   } else {
+#ifdef USE_ROCM
+    gpuAtomicAddNoReturn(
+        reinterpret_cast<at::BFloat16*>(tensor) + index, static_cast<at::BFloat16>(value));
+#else
     atomicAdd(
         reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
+#endif
   }
 #endif
 }
@@ -144,4 +222,7 @@ __device__ __forceinline__ void fastAtomicAdd(
   }
 }
 
+#undef ATOMICADD
+#undef NATIVE_ZERO_BF16
+
 } // namespace at::native
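
Note (not part of the patch): when ROCM_VERSION >= 60201, ATOMICADD resolves to preview_unsafeAtomicAdd, so the packed __half2 / __nv_bfloat162 additions in fastSpecializedAtomicAdd use either the gfx940/941/942 flat_atomic_fadd builtins or the 32-bit compare-and-swap fallback. The sketch below is a hypothetical caller for illustration only; it assumes fastAtomicAdd keeps its existing (tensor, index, numel, value, fast_atomics) signature from this header, and the kernel and its names are not part of this change.

#include <ATen/native/cuda/KernelUtils.cuh>

// Illustrative scatter-add kernel (hypothetical, for context only): each
// thread adds one source element into out[idx[i]]. With fast_atomics=true,
// Half/BFloat16 values can take the packed two-element path guarded by
// ATOMICADD above.
template <typename scalar_t, typename index_t>
__global__ void scatter_add_kernel(
    scalar_t* out,
    const scalar_t* src,
    const index_t* idx,
    index_t n,
    index_t out_numel) {
  const index_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    at::native::fastAtomicAdd(out, idx[i], out_numel, src[i], /*fast_atomics=*/true);
  }
}

On ROCm older than 6.2.1, and on CUDA architectures below 700 (Half) or 800 (BFloat16), the same call degrades to the gpuAtomicAddNoReturn branch guarded at the top of each fastSpecializedAtomicAdd specialization.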