
Commit 2435d94

Michael Carilli authored and facebook-github-bot committed
Fix FP16 fastAtomicAdd for one case where tensor start address is not 32 bit aligned (#44642)
Summary: For #44206 and #42218, I'd like to update trilinear interpolate backward and grid_sample backward to use `fastAtomicAdd`. As a prelude, I spotted a UB risk in `fastAtomicAdd`: I think the existing code incurs a misaligned `__half2` atomicAdd when `index` is odd and `tensor` is not 32-bit aligned (`index % 2 == 1` and `reinterpret_cast<std::uintptr_t>(tensor) % sizeof(__half2) != 0`). In this case we think we're `!low_bit` and go down the `!low_bit` code path, but in fact the target element is the low half of its 32-bit word. The discussion on the original [fastAtomicAdd PR](#21879) does not appear to have considered that case explicitly. I wanted to push my tentative fix for discussion ASAP; cc jjsjann123 and mkolod as the original authors of `fastAtomicAdd`. (I'm also curious why we need the `reinterpret_cast<std::uintptr_t>(tensor)` for the address modding, but that's minor.)

Pull Request resolved: #44642
Reviewed By: mruberry
Differential Revision: D23699820
Pulled By: ngimel
fbshipit-source-id: 0db57150715ebb45e6a1fb36897e46f00d61defd
1 parent 2fd142a commit 2435d94

File tree: 1 file changed (+7, −6 lines)


aten/src/ATen/native/cuda/KernelUtils.cuh

Lines changed: 7 additions & 6 deletions
```diff
@@ -21,20 +21,21 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
         reinterpret_cast<at::Half*>(tensor) + index,
         static_cast<at::Half>(value));
 #else
-  bool low_bit = (index % 2 == 0) &&
-    (reinterpret_cast<std::uintptr_t>(tensor) % sizeof(__half2) == 0);
+  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
+  __half* target_addr = reinterpret_cast<__half*>(tensor + index);
+  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);

-  if (low_bit && index < (numel - 1)) {
+  if (low_byte && index < (numel - 1)) {
     __half2 value2;
     value2.x = value;
     value2.y = __int2half_rz(0);
-    atomicAdd(reinterpret_cast<__half2*>(tensor) + index / 2, value2);
+    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);

-  } else if (!low_bit && index > 0) {
+  } else if (!low_byte && index > 0) {
     __half2 value2;
     value2.x = __int2half_rz(0);
     value2.y = value;
-    atomicAdd(reinterpret_cast<__half2*>(tensor) + index / 2, value2);
+    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);

   } else {
     atomicAdd(
```
