Skip to content

Commit e72b617

Browse files
izdebyfacebook-github-bot
authored andcommitted
Intoducing bfloat16 type (#21522)
Summary: Pull Request resolved: #21522 ghimport-source-id: 4803f19 Test Plan: Imported from OSS Differential Revision: D15819369 Pulled By: izdeby fbshipit-source-id: 46408dc316a5c4dc644a736dc42da2422b34bcb9
1 parent de5a481 commit e72b617

File tree

5 files changed

+157
-2
lines changed

5 files changed

+157
-2
lines changed

c10/test/util/bfloat16_test.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include <c10/util/BFloat16.h>
2+
#include <gtest/gtest.h>
3+
4+
namespace {
5+
float float_from_bytes(
6+
uint32_t sign,
7+
uint32_t exponent,
8+
uint32_t fraction
9+
) {
10+
uint32_t bytes;
11+
bytes = 0;
12+
bytes |= sign;
13+
bytes <<= 8;
14+
bytes |= exponent;
15+
bytes <<= 23;
16+
bytes |= fraction;
17+
18+
float res;
19+
std::memcpy(&res, &bytes, sizeof(res));
20+
return res;
21+
}
22+
23+
TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
24+
float in[100];
25+
for (int i = 0; i < 100; ++i) {
26+
in[i] = i + 1.25;
27+
}
28+
29+
c10::BFloat16 bfloats[100];
30+
float out[100];
31+
32+
for (int i = 0; i < 100; ++i) {
33+
bfloats[i].x = c10::detail::bits_from_f32(in[i]);
34+
out[i] = c10::detail::f32_from_bits(bfloats[i].x);
35+
36+
// The relative error should be less than 1/(2^7) since bfloat16
37+
// has 7 bits mantissa.
38+
EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128);
39+
}
40+
}
41+
42+
TEST(BFloat16Conversion, NaN) {
43+
float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
44+
EXPECT_TRUE(std::isnan(inNaN));
45+
46+
c10::BFloat16 a = c10::BFloat16(inNaN);
47+
float out = c10::detail::f32_from_bits(a.x);
48+
49+
EXPECT_TRUE(std::isnan(out));
50+
}
51+
52+
TEST(BFloat16Conversion, Inf) {
53+
float inInf = float_from_bytes(0, 0xFF, 0);
54+
EXPECT_TRUE(std::isinf(inInf));
55+
56+
c10::BFloat16 a = c10::BFloat16(inInf);
57+
float out = c10::detail::f32_from_bits(a.x);
58+
59+
EXPECT_TRUE(std::isinf(out));
60+
}
61+
62+
TEST(BFloat16Conversion, SmallestDenormal) {
63+
float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero subnormal number
64+
c10::BFloat16 a = c10::BFloat16(in);
65+
float out = c10::detail::f32_from_bits(a.x);
66+
67+
EXPECT_FLOAT_EQ(in, out);
68+
}
69+
} // namespace

c10/util/BFloat16-inl.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include <c10/macros/Macros.h>
4+
5+
namespace c10 {
6+
7+
/// Constructors
8+
inline C10_HOST_DEVICE BFloat16::BFloat16(float value) {
9+
x = detail::bits_from_f32(value);
10+
}
11+
12+
/// Implicit conversions
13+
inline C10_HOST_DEVICE BFloat16::operator float() const {
14+
return detail::f32_from_bits(x);
15+
}
16+
17+
} // namespace c10

c10/util/BFloat16.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#pragma once
2+
3+
// Defines the bloat16 type (brain floating-point). This representation uses
4+
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
5+
6+
#include <c10/macros/Macros.h>
7+
#include <cmath>
8+
#include <cstring>
9+
10+
namespace c10 {
11+
12+
namespace detail {
13+
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
14+
float res = 0;
15+
uint32_t tmp = src;
16+
tmp <<= 16;
17+
18+
#ifdef __HIP_PLATFORM_HCC__
19+
float* tempRes;
20+
21+
// We should be using memcpy in order to respect the strict aliasing rule
22+
// but it fails in the HIP environment.
23+
tempRes = reinterpret_cast<float*>(&tmp);
24+
res = *tempRes;
25+
#else
26+
std::memcpy(&res, &tmp, sizeof(tmp));
27+
#endif
28+
29+
return res;
30+
}
31+
32+
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
33+
uint32_t res = 0;
34+
35+
#ifdef __HIP_PLATFORM_HCC__
36+
// We should be using memcpy in order to respect the strict aliasing rule
37+
// but it fails in the HIP environment.
38+
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
39+
res = *tempRes;
40+
#else
41+
std::memcpy(&res, &src, sizeof(res));
42+
#endif
43+
44+
return res >> 16;
45+
}
46+
} // namespace detail
47+
48+
struct alignas(2) BFloat16 {
49+
uint16_t x;
50+
51+
// HIP wants __host__ __device__ tag, CUDA does not
52+
#ifdef __HIP_PLATFORM_HCC__
53+
C10_HOST_DEVICE BFloat16() = default;
54+
#else
55+
BFloat16() = default;
56+
57+
#endif
58+
59+
explicit inline C10_HOST_DEVICE BFloat16(float value);
60+
explicit inline C10_HOST_DEVICE operator float() const;
61+
};
62+
63+
} // namespace c10
64+
65+
66+
#include <c10/util/BFloat16-inl.h>

c10/util/typeid.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, at::Half*)
8585
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, c10::qint8)
8686
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, c10::quint8)
8787
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(31, c10::qint32)
88-
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(32, _CaffeHighestPreallocatedTypeId)
88+
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(32, at::BFloat16)
89+
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(33, _CaffeHighestPreallocatedTypeId)
8990

9091
} // namespace caffe2

c10/util/typeid.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <c10/util/qint32.h>
2828
#include <c10/util/qint8.h>
2929
#include <c10/util/quint8.h>
30+
#include <c10/util/BFloat16.h>
3031

3132
/*
3233
* TypeIdentifier is a small type containing an id.
@@ -629,6 +630,7 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, at::Half*)
629630
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, c10::qint8)
630631
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, c10::quint8)
631632
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(31, c10::qint32)
632-
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(32, _CaffeHighestPreallocatedTypeId)
633+
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(32, at::BFloat16)
634+
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(33, _CaffeHighestPreallocatedTypeId)
633635

634636
} // namespace caffe2

0 commit comments

Comments
 (0)