Skip to content

Commit fe0c8d5

Browse files
umar456 authored and 9prady9 committed
Fix problems in RNG for older compute architectures with fp16
1 parent dffa8c5 commit fe0c8d5

File tree

7 files changed

+197
-114
lines changed

7 files changed

+197
-114
lines changed

src/backend/common/half.hpp

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -822,7 +822,7 @@ AF_CONSTEXPR __DH__ static inline bool isinf(half val) noexcept;
822822
AF_CONSTEXPR __DH__ static inline bool isnan(common::half val) noexcept;
823823

824824
class alignas(2) half {
825-
native_half_t data_ = 0;
825+
native_half_t data_ = native_half_t();
826826

827827
#if !defined(NVCC) && !defined(__CUDACC_RTC__)
828828
// NVCC on OSX performs a weird transformation where it removes the std::
@@ -881,11 +881,19 @@ class alignas(2) half {
881881
return *this;
882882
}
883883

884-
#if defined(__CUDA_ARCH__)
885-
AF_CONSTEXPR __DH__ explicit half(const __half& value) noexcept
886-
: data_(value) {}
887-
AF_CONSTEXPR __DH__ half& operator=(__half&& value) noexcept {
888-
data_ = value;
884+
#if defined(NVCC) || defined(__CUDACC_RTC__)
885+
AF_CONSTEXPR __DH__ explicit half(__half value) noexcept
886+
#ifdef __CUDA_ARCH__
887+
: data_(value) {
888+
}
889+
#else
890+
: data_(*reinterpret_cast<native_half_t*>(&value)) {
891+
}
892+
#endif
893+
AF_CONSTEXPR __DH__ half& operator=(__half value) noexcept {
894+
// NOTE Assignment to ushort from __half only works with device code.
895+
// reinterpreting the bits through a native_half_t pointer instead
896+
data_ = *reinterpret_cast<native_half_t*>(&value);
889897
return *this;
890898
}
891899
#endif
@@ -988,7 +996,11 @@ class alignas(2) half {
988996

989997
AF_CONSTEXPR static half infinity() {
990998
half out;
999+
#ifdef __CUDA_ARCH__
1000+
out.data_ = __half_raw{0x7C00};
1001+
#else
9911002
out.data_ = 0x7C00;
1003+
#endif
9921004
return out;
9931005
}
9941006
};

src/backend/cpu/kernel/random_engine.hpp

Lines changed: 13 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -32,22 +32,9 @@ static const double PI_VAL =
3232
3.1415926535897932384626433832795028841971693993751058209749445923078164;
3333

3434
// Conversion to half adapted from Random123
35-
#define HALF_FACTOR ((1.0f) / (std::numeric_limits<ushort>::max() + (1.0f)))
36-
#define HALF_HALF_FACTOR ((0.5f) * HALF_FACTOR)
37-
38-
// Conversion to half adapted from Random123
39-
#define SIGNED_HALF_FACTOR \
40-
((1.0f) / (std::numeric_limits<short>::max() + (1.0f)))
41-
#define SIGNED_HALF_HALF_FACTOR ((0.5f) * SIGNED_HALF_FACTOR)
42-
43-
#define DBL_FACTOR \
44-
((1.0) / (std::numeric_limits<unsigned long long>::max() + (1.0)))
45-
#define HALF_DBL_FACTOR ((0.5) * DBL_FACTOR)
46-
47-
// Conversion to floats adapted from Random123
48-
#define SIGNED_DBL_FACTOR \
49-
((1.0) / (std::numeric_limits<long long>::max() + (1.0)))
50-
#define SIGNED_HALF_DBL_FACTOR ((0.5) * SIGNED_DBL_FACTOR)
35+
constexpr float unsigned_half_factor =
36+
((1.0f) / (std::numeric_limits<ushort>::max() + (1.0f)));
37+
constexpr float unsigned_half_half_factor((0.5f) * unsigned_half_factor);
5138

5239
template<typename T>
5340
T transform(uint *val, uint index);
@@ -85,14 +72,19 @@ static float getFloatNegative11(uint *val, uint index) {
8572
// Generates rationals in [0, 1)
8673
common::half getHalf01(uint *val, uint index) {
8774
float v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
88-
return static_cast<common::half>(fmaf(v, HALF_FACTOR, HALF_HALF_FACTOR));
75+
return static_cast<common::half>(
76+
fmaf(v, unsigned_half_factor, unsigned_half_half_factor));
8977
}
9078

9179
// Generates rationals in (-1, 1]
9280
static common::half getHalfNegative11(uint *val, uint index) {
9381
float v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
94-
return static_cast<common::half>(
95-
fmaf(v, SIGNED_HALF_FACTOR, SIGNED_HALF_HALF_FACTOR));
82+
// Conversion to half adapted from Random123
83+
constexpr float factor =
84+
((1.0f) / (std::numeric_limits<short>::max() + (1.0f)));
85+
constexpr float half_factor = ((0.5f) * factor);
86+
87+
return static_cast<common::half>(fmaf(v, factor, half_factor));
9688
}
9789

9890
// Generates rationals in [0, 1)
@@ -160,8 +152,8 @@ double transform<double>(uint *val, uint index) {
160152
template<>
161153
common::half transform<common::half>(uint *val, uint index) {
162154
float v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
163-
return static_cast<common::half>(1.f -
164-
fmaf(v, HALF_FACTOR, HALF_HALF_FACTOR));
155+
return static_cast<common::half>(
156+
1.f - fmaf(v, unsigned_half_factor, unsigned_half_half_factor));
165157
}
166158

167159
// Generates rationals in [-1, 1)

0 commit comments

Comments
 (0)