Enabled bfloat16 for cuda #27259
The main hunk adds a full `THCNumerics` specialization for `at::BFloat16`, alongside the existing `float` specialization:

```diff
@@ -273,6 +273,77 @@ struct THCNumerics<float> {
   static inline __host__ __device__ bool isinf(float a) { return ::isinf(a); }
 };
 
+template <>
+struct THCNumerics<at::BFloat16> {
+  static inline __host__ __device__ at::BFloat16 min() { return at::numeric_limits<at::BFloat16>::lowest(); }
+  static inline __host__ __device__ at::BFloat16 max() { return at::numeric_limits<at::BFloat16>::max(); }
+  static inline __host__ __device__ at::BFloat16 lower_bound() { return at::numeric_limits<at::BFloat16>::lower_bound(); }
+  static inline __host__ __device__ at::BFloat16 upper_bound() { return at::numeric_limits<at::BFloat16>::upper_bound(); }
+
+  static inline __host__ __device__ bool lt(at::BFloat16 a, at::BFloat16 b) { return a < b; }
+  static inline __host__ __device__ bool le(at::BFloat16 a, at::BFloat16 b) { return a <= b; }
+  static inline __host__ __device__ bool gt(at::BFloat16 a, at::BFloat16 b) { return a > b; }
+  static inline __host__ __device__ bool ge(at::BFloat16 a, at::BFloat16 b) { return a >= b; }
+  static inline __host__ __device__ bool eq(at::BFloat16 a, at::BFloat16 b) { return a == b; }
+  static inline __host__ __device__ bool ne(at::BFloat16 a, at::BFloat16 b) { return a != b; }
+
+  static inline __host__ __device__ at::BFloat16 lgamma(at::BFloat16 a) { return lgammaf(a); }
+  static inline __host__ __device__ at::BFloat16 exp  (at::BFloat16 a) { return expf(a); }
+  static inline __host__ __device__ at::BFloat16 exp10(at::BFloat16 a) { return exp10f(a); }
+  static inline __host__ __device__ at::BFloat16 log  (at::BFloat16 a) { return logf(a); }
+  static inline __host__ __device__ at::BFloat16 log10(at::BFloat16 a) { return log10f(a); }
+  static inline __host__ __device__ at::BFloat16 log1p(at::BFloat16 a) { return log1pf(a); }
+  static inline __host__ __device__ at::BFloat16 log2 (at::BFloat16 a) { return log2f(a); }
+  static inline __host__ __device__ at::BFloat16 expm1(at::BFloat16 a) { return expm1f(a); }
+  static inline __host__ __device__ at::BFloat16 cos  (at::BFloat16 a) { return cosf(a); }
+  static inline __host__ __device__ at::BFloat16 sin  (at::BFloat16 a) { return sinf(a); }
+  static inline __host__ __device__ at::BFloat16 sqrt (at::BFloat16 a) { return sqrtf(a); }
+  static inline __host__ __device__ at::BFloat16 rsqrt(at::BFloat16 a) { return rsqrtf(a); }
+  static inline __host__ __device__ at::BFloat16 floor(at::BFloat16 a) { return floorf(a); }
+  static inline __host__ __device__ at::BFloat16 trunc(at::BFloat16 a) { return truncf(a); }
+  static inline __host__ __device__ at::BFloat16 acos (at::BFloat16 a) { return acosf(a); }
+  static inline __host__ __device__ at::BFloat16 cosh (at::BFloat16 a) { return coshf(a); }
+  static inline __host__ __device__ at::BFloat16 acosh(at::BFloat16 a) { return acoshf(a); }
+  static inline __host__ __device__ at::BFloat16 asin (at::BFloat16 a) { return asinf(a); }
+  static inline __host__ __device__ at::BFloat16 sinh (at::BFloat16 a) { return sinhf(a); }
+  static inline __host__ __device__ at::BFloat16 asinh(at::BFloat16 a) { return asinhf(a); }
+  static inline __host__ __device__ at::BFloat16 tan  (at::BFloat16 a) { return tanf(a); }
+  static inline __host__ __device__ at::BFloat16 atan (at::BFloat16 a) { return atanf(a); }
+  static inline __host__ __device__ at::BFloat16 tanh (at::BFloat16 a) { return tanhf(a); }
+  static inline __host__ __device__ at::BFloat16 erf  (at::BFloat16 a) { return erff(a); }
+  static inline __host__ __device__ at::BFloat16 erfc (at::BFloat16 a) { return erfcf(a); }
+  static inline __host__ __device__ at::BFloat16 abs  (at::BFloat16 a) { return fabsf(a); }
+  static inline __host__ __device__ at::BFloat16 round(at::BFloat16 a) { return nearbyintf(a); }
+  static inline __host__ __device__ at::BFloat16 frac (at::BFloat16 a) { return a - truncf(a); }
+  static inline __host__ __device__ at::BFloat16 cinv (at::BFloat16 a) { return 1.0f / a; }
+  static inline __host__ __device__ at::BFloat16 add  (at::BFloat16 a, at::BFloat16 b) { return a + b; }
+  static inline __host__ __device__ at::BFloat16 div  (at::BFloat16 a, at::BFloat16 b) { return a / b; }
+  static inline __host__ __device__ at::BFloat16 mul  (at::BFloat16 a, at::BFloat16 b) { return a * b; }
+  static inline __host__ __device__ at::BFloat16 sub  (at::BFloat16 a, at::BFloat16 b) { return a - b; }
+  static inline __host__ __device__ at::BFloat16 pow  (at::BFloat16 a, at::BFloat16 b) { return powf(a, b); }
+  static inline __host__ __device__ at::BFloat16 atan2(at::BFloat16 a, at::BFloat16 b) { return atan2f(a, b); }
+
+  static inline __host__ __device__ bool isnan(at::BFloat16 a) {
+#ifdef _MSC_VER
+    // Windows requires this explicit conversion. The reason is unclear
+    // related issue with clang: https://reviews.llvm.org/D37906
+    return ::isnan((float) a);
+#else
+    return ::isnan(a);
+#endif
+  }
+
+  static inline __host__ __device__ bool isinf(at::BFloat16 a) {
+#ifdef _MSC_VER
+    // Windows requires this explicit conversion. The reason is unclear
+    // related issue with clang: https://reviews.llvm.org/D37906
+    return ::isinf((float) a);
+#else
+    return ::isinf(a);
+#endif
+  }
+};
```
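These wrappers type-check because `at::BFloat16` converts implicitly to and from `float`: each argument promotes to `float`, the fp32 math function does the work, and the result truncates back to bfloat16 on return. Below is a minimal, self-contained sketch of that round-trip with a toy `bfloat16_t` (illustrative only; it truncates rather than rounding to nearest-even, and it is not PyTorch's actual `at::BFloat16` implementation):

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

// Toy bfloat16: keeps only the top 16 bits of an IEEE-754 float.
struct bfloat16_t {
  uint16_t bits;
  bfloat16_t() = default;
  bfloat16_t(float f) {            // implicit float -> bfloat16
    uint32_t u; std::memcpy(&u, &f, 4);
    bits = uint16_t(u >> 16);
  }
  operator float() const {         // implicit bfloat16 -> float
    uint32_t u = uint32_t(bits) << 16;
    float f; std::memcpy(&f, &u, 4);
    return f;
  }
};

// Same shape as the THCNumerics wrappers: the argument converts to float,
// expf computes in fp32, and the result converts back to bfloat16.
static inline bfloat16_t exp(bfloat16_t a) { return expf(a); }

int main() {
  bfloat16_t x = 1.0f;
  printf("exp(1) in bfloat16 ~ %f\n", float(exp(x)));  // ~2.703125 (truncated)
}
```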
Comment on lines +326 to +346

Contributor (author): @ezyang, this part
The diff also carries this deprecation note:

```cpp
// DEPRECATED: use math functions from std and cuda math API (if needed)
// note that the functions exp10, erfinv and cinv
// are not in the std namespace
```
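Per that note, new code would call `std::` math through an explicit `float` conversion rather than going through these wrappers; `exp10`, `erfinv`, and `cinv` are the exceptions without `std::` counterparts. A hedged sketch of the substitution (the helper names here are illustrative, not part of the PR):

```cpp
#include <cmath>

// std::log exists: compute in fp32, convert back on return.
template <typename T>
T log_via_std(T a) {
  return std::log(static_cast<float>(a));
}

// exp10 is not in std; CUDA offers ::exp10f, and on the host
// powf(10, x) expresses the same computation portably.
template <typename T>
T exp10_portable(T a) {
  return powf(10.0f, static_cast<float>(a));
}
```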
A new five-line file instantiates the masked-tensor generics for bfloat16:

```diff
@@ -0,0 +1,5 @@
+#include <THC/THCTensorMasked.cuh>
+#include <THC/THCTensor.hpp>
+
+#include <THC/generic/THCTensorMasked.cu>
+#include <THC/THCGenerateBFloat16Type.h>
```
A second new file does the same for the math-reduce generics:

```diff
@@ -0,0 +1,5 @@
+#include <THC/THCTensorMathReduce.cuh>
+#include <THC/THCTensor.hpp>
+
+#include <THC/generic/THCTensorMathReduce.cu>
+#include <THC/THCGenerateBFloat16Type.h>
```
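Both new translation units rely on THC's include-twice "generic" pattern: the `generic/*.cu` file is written against a placeholder scalar type, and `THCGenerateBFloat16Type.h` defines that placeholder as `at::BFloat16` and re-includes the generic file. A simplified sketch of the mechanism, assuming the real headers follow the usual THC layout (they also define further macros, e.g. for the accumulate type and tensor names):

```cpp
// --- THC/generic/Example.cu (sketch) ---
// First inclusion (from ExampleBFloat16.cu) only records this file's path.
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "THC/generic/Example.cu"
#else
// Second inclusion, via THCGenerateBFloat16Type.h: scalar_t is now defined,
// so the declarations below are stamped out for at::BFloat16.
void THCTensor_exampleLaunch(scalar_t* data, int n);
#endif

// --- THC/THCGenerateBFloat16Type.h (sketch) ---
#define scalar_t at::BFloat16
#include THC_GENERIC_FILE   // re-include the recorded generic file
#undef scalar_t
#undef THC_GENERIC_FILE
```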
nit: (you don't have to do anything about this, just fyi for later) generally people prefer `static_cast<float>` in C++, as it prevents the C-style cast from, e.g., being treated like a `reinterpret_cast`.
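As a small illustration of the hazard behind that nit (a self-contained example, not from the PR): when no value conversion exists, a C-style cast silently falls back to `reinterpret_cast` semantics, while `static_cast` turns the same mistake into a compile error.

```cpp
#include <cstdint>

int main() {
  double d = 2.718;

  // C-style cast: compiles, but silently acts as a reinterpret_cast,
  // punning the bytes of a double as an int64_t.
  int64_t* p = (int64_t*)&d;

  // static_cast rejects the same conversion at compile time:
  // int64_t* q = static_cast<int64_t*>(&d);   // error

  // For the (float) casts in this diff, static_cast<float>(a) expresses
  // the intended value conversion and nothing more.
  float f = static_cast<float>(d);
  (void)p; (void)f;
  return 0;
}
```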