10 changes: 10 additions & 0 deletions aten/src/ATen/Declarations.cwrap
@@ -6,6 +6,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
device_guard: False
return: argument 0
options:
@@ -27,6 +28,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
cname: maskedFill
variants: function
return: self
@@ -42,6 +44,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
cname: maskedFillBool
variants: function
return: self
@@ -57,6 +60,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
cname: maskedCopy
variants: function
return: self
@@ -71,6 +75,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
cname: maskedCopyBool
variants: function
return: self
@@ -86,6 +91,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
variants:
- function
return: argument 0
@@ -102,6 +108,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
variants:
- function
return: argument 0
@@ -119,6 +126,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
variants:
- function
return: argument 0
@@ -138,6 +146,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
arguments:
- THTensor* self
]]
@@ -1589,6 +1598,7 @@
cpu_bool: True
cuda_bool: True
cpu_bfloat16: True
cuda_bfloat16: True
return: self
arguments:
- arg: THTensor* self
8 changes: 8 additions & 0 deletions aten/src/ATen/cuda/NumericLimits.cuh
@@ -94,6 +94,14 @@ struct numeric_limits<at::Half> {
static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
};

template <>
struct numeric_limits<at::BFloat16> {
static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
};

template <>
struct numeric_limits<float> {
static inline __host__ __device__ float lowest() { return -FLT_MAX; }
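Note: the four bit patterns in the new at::BFloat16 specialization are the bfloat16 encodings of -max, +max, -inf and +inf (bfloat16 occupies the high 16 bits of a float32: 1 sign bit, 8 exponent bits, 7 mantissa bits). A small standalone check (not part of this PR) that decodes them by widening each pattern into the top half of a float:

```cpp
// Hypothetical verification helper (plain host C++, no ATen dependency).
#include <cstdint>
#include <cstring>
#include <cstdio>

static float bfloat16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;  // bfloat16 is the high half of a float32
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  std::printf("0xFF7F lowest()      = %g\n", bfloat16_bits_to_float(0xFF7F));  // -3.38953e+38
  std::printf("0x7F7F max()         = %g\n", bfloat16_bits_to_float(0x7F7F));  //  3.38953e+38
  std::printf("0xFF80 lower_bound() = %g\n", bfloat16_bits_to_float(0xFF80));  // -inf
  std::printf("0x7F80 upper_bound() = %g\n", bfloat16_bits_to_float(0x7F80));  //  inf
}
```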
1 change: 1 addition & 0 deletions aten/src/ATen/function_wrapper.py
@@ -586,6 +586,7 @@ def __getitem__(self, x):
'with_gil': bool,
'cpu_half': bool,
'cpu_bfloat16': bool,
'cuda_bfloat16': bool,
'deprecated': bool,
'cpu_bool': bool,
'cuda_bool': bool,
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/BinaryArithmeticKernel.cu
@@ -11,7 +11,7 @@
namespace at { namespace native {

void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
auto alpha = alpha_scalar.to<scalar_t>();
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a + alpha * b;
@@ -36,7 +36,7 @@ void div_kernel_cuda(TensorIterator& iter) {
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "div_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a / b;
});
@@ -51,7 +51,7 @@ void mul_kernel_cuda(TensorIterator& iter) {
return a && b;
});
} else {
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "mul_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a * b;
});
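Note: switching these dispatches from AT_DISPATCH_ALL_TYPES_AND2 to AT_DISPATCH_ALL_TYPES_AND3 with kBFloat16 adds one more dtype case, so the add/sub, div, and mul lambdas are now also instantiated with scalar_t = at::BFloat16 on CUDA; the fill_cuda and index_put changes below follow the same pattern. A rough sketch of the idea (a simplified stand-in written for this note, not ATen's actual macro expansion):

```cpp
// Minimal sketch of the AT_DISPATCH_* pattern: switch on the runtime dtype and
// call a generic body with the matching C++ element type bound to scalar_t.
// Dtype and bfloat16_stub are stand-ins invented for this sketch.
#include <cstdio>

enum class Dtype { Float, Double, BFloat16 /* , ... */ };

struct bfloat16_stub { unsigned short bits; };  // stand-in for at::BFloat16

template <typename Body>
void dispatch_sketch(Dtype dtype, Body&& body) {
  switch (dtype) {
    case Dtype::Float:    body(float{});          break;
    case Dtype::Double:   body(double{});         break;
    case Dtype::BFloat16: body(bfloat16_stub{});  break;  // the case this PR effectively enables
  }
}

int main() {
  dispatch_sketch(Dtype::BFloat16, [](auto zero) {
    using scalar_t = decltype(zero);  // plays the role of scalar_t inside the real dispatch lambda
    std::printf("kernel body instantiated, sizeof(scalar_t) = %zu\n", sizeof(scalar_t));
  });
}
```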
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/CUDAScalar.cu
@@ -9,8 +9,8 @@ namespace native {

Scalar _local_scalar_dense_cuda(const Tensor& self) {
Scalar r;
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
scalar_t value;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
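Note: for a bfloat16 tensor, the dispatch above makes _local_scalar_dense_cuda copy sizeof(scalar_t) = 2 bytes from device to host and rebuild the scalar from those bits. A hedged sketch of that step in isolation (device_elem and stream are assumed inputs; this is not the ATen code path itself):

```cpp
// Sketch of the single-element device-to-host copy (assumes a valid device
// pointer to one bfloat16 element and a valid CUDA stream).
#include <cuda_runtime.h>
#include <cstdint>

uint16_t read_bfloat16_bits(const void* device_elem, cudaStream_t stream) {
  uint16_t host_bits = 0;
  cudaMemcpyAsync(&host_bits, device_elem, sizeof(host_bits),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // make sure the value has landed before it is used
  return host_bits;               // ATen would wrap these bits via at::BFloat16::from_bits()
}
```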
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/FillKernel.cu
@@ -7,7 +7,7 @@
namespace at { namespace native {

void fill_kernel_cuda(TensorIterator& iter, Scalar value) {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, iter.dtype(), "fill_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "fill_cuda", [&]() {
auto value_converted = value.to<scalar_t>();
gpu_kernel(iter, [value_converted]GPU_LAMBDA() -> scalar_t {
return value_converted;
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/IndexKernel.cu
@@ -97,7 +97,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR

static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_put", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_put_kernel_impl<dtype>(iter, index_size, index_stride);
});
1 change: 1 addition & 0 deletions aten/src/ATen/native_parse.py
@@ -414,6 +414,7 @@ def run(paths):
declaration['matches_jit_signature'] = func.get('matches_jit_signature', True)
declaration['cpu_half'] = func.get('cpu_half', False)
declaration['cpu_bfloat16'] = func.get('cpu_bfloat16', False)
declaration['cuda_bfloat16'] = func.get('cuda_bfloat16', False)
declaration['cpu_bool'] = func.get('cpu_bool', False)
declaration['cuda_bool'] = func.get('cuda_bool', False)
declaration['deprecated'] = func.get('deprecated', False)
8 changes: 4 additions & 4 deletions aten/src/ATen/preprocess_declarations.py
@@ -71,14 +71,14 @@ def expand(types):
if 'CPU' in backend_types:
backend_types['CPU'].discard('Half')

# special case remove BFloat16 for cpu unless it is explicitly enabled
# special case remove BFloat16 for cpu and cuda unless it is explicitly enabled
if not option.get('cpu_bfloat16', False):
if 'CPU' in backend_types:
backend_types['CPU'].discard('BFloat16')

# TODO: remove this hack once support for a bfloat16 tensor for CUDA is enabled
if 'CUDA' in backend_types:
backend_types['CUDA'].discard('BFloat16')
if not option.get('cuda_bfloat16', False):
if 'CUDA' in backend_types:
backend_types['CUDA'].discard('BFloat16')

# special cases remove bool for cpu and cuda unless it is explicitly enabled
if not option.get('cpu_bool', False):
12 changes: 10 additions & 2 deletions aten/src/THC/CMakeLists.txt
@@ -9,7 +9,7 @@ set(extra_src)
# loop over all types
foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
# loop over files which need to be split between types (because of long compile times)
foreach(THC_FILE TensorSort TensorMathCompareT TensorMathPointwise TensorMathCompare TensorMathReduce TensorMasked)
foreach(THC_FILE TensorSort TensorMathPointwise TensorMathReduce TensorMasked)
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu")
FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu"
"#include <THC/THC${THC_FILE}.cuh>\n#include <THC/THCTensor.hpp>\n\n#include <THC/generic/THC${THC_FILE}.cu>\n#include <THC/THCGenerate${THC_TYPE}Type.h>\n")
@@ -18,14 +18,22 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
endforeach()
endforeach()

foreach(THC_FILE TensorMathCompareT TensorMathCompare TensorMathReduce TensorMasked TensorMathPointwise)
foreach(THC_FILE TensorMathPointwise TensorMathReduce TensorMasked)
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu")
FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu"
"#include <THC/THC${THC_FILE}.cuh>\n#include <THC/THCTensor.hpp>\n\n#include <THC/generic/THC${THC_FILE}.cu>\n#include <THC/THCGenerateBoolType.h>\n")
endif()
LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu")
endforeach()

foreach(THC_FILE TensorMathReduce TensorMasked)
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}BFloat16.cu")
FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}BFloat16.cu"
"#include <THC/THC${THC_FILE}.cuh>\n#include <THC/THCTensor.hpp>\n\n#include <THC/generic/THC${THC_FILE}.cu>\n#include <THC/THCGenerateBFloat16Type.h>\n")
endif()
LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}BFloat16.cu")
endforeach()

set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THCGeneral.cpp
2 changes: 2 additions & 0 deletions aten/src/THC/THCGenerateBFloat16Type.h
@@ -4,6 +4,7 @@
#include <c10/util/BFloat16.h>

#define scalar_t at::BFloat16
#define accreal float
#define Real BFloat16

#define CReal CudaBFloat16
@@ -12,6 +13,7 @@
#line 1 THC_GENERIC_FILE
#include THC_GENERIC_FILE
#undef scalar_t
#undef accreal
#undef Real

#undef CReal
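Note: the new `#define accreal float` makes float the accumulation type for the TH CUDA code generated from this header, so reductions over bfloat16 inputs accumulate in float rather than in bfloat16 itself. With only 8 significand bits, a running bfloat16 sum stops growing once the increment drops below its rounding granularity. A host-side illustration (assumes round-to-nearest-even float to bfloat16 rounding, written just for this note):

```cpp
// Why reductions want a wider accumulator: summing 1.0 a thousand times.
#include <cstdint>
#include <cstring>
#include <cstdio>

// Round a float to the nearest bfloat16 value (nearest-even; NaNs not handled).
static float round_to_bfloat16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);  // round-to-nearest-even on the discarded low half
  bits &= 0xFFFF0000u;                    // keep only the bfloat16-representable part
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  float acc_bf16 = 0.0f, acc_float = 0.0f;
  for (int i = 0; i < 1000; ++i) {
    acc_bf16 = round_to_bfloat16(acc_bf16 + 1.0f);  // bfloat16 accumulator: stalls at 256
    acc_float += 1.0f;                              // float accumulator (accreal): reaches 1000
  }
  std::printf("bfloat16 accumulator = %g, float accumulator = %g\n", acc_bf16, acc_float);
}
```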
71 changes: 71 additions & 0 deletions aten/src/THC/THCNumerics.cuh
@@ -273,6 +273,77 @@ struct THCNumerics<float> {
static inline __host__ __device__ bool isinf(float a) { return ::isinf(a); }
};

template <>
struct THCNumerics<at::BFloat16> {
static inline __host__ __device__ at::BFloat16 min() { return at::numeric_limits<at::BFloat16>::lowest(); }
static inline __host__ __device__ at::BFloat16 max() { return at::numeric_limits<at::BFloat16>::max(); }
static inline __host__ __device__ at::BFloat16 lower_bound() { return at::numeric_limits<at::BFloat16>::lower_bound(); }
static inline __host__ __device__ at::BFloat16 upper_bound() { return at::numeric_limits<at::BFloat16>::upper_bound(); }

static inline __host__ __device__ bool lt(at::BFloat16 a, at::BFloat16 b) { return a < b; }
static inline __host__ __device__ bool le(at::BFloat16 a, at::BFloat16 b) { return a <= b; }
static inline __host__ __device__ bool gt(at::BFloat16 a, at::BFloat16 b) { return a > b; }
static inline __host__ __device__ bool ge(at::BFloat16 a, at::BFloat16 b) { return a >= b; }
static inline __host__ __device__ bool eq(at::BFloat16 a, at::BFloat16 b) { return a == b; }
static inline __host__ __device__ bool ne(at::BFloat16 a, at::BFloat16 b) { return a != b; }

static inline __host__ __device__ at::BFloat16 lgamma(at::BFloat16 a) { return lgammaf(a);}
static inline __host__ __device__ at::BFloat16 exp (at::BFloat16 a) { return expf(a); }
static inline __host__ __device__ at::BFloat16 exp10(at::BFloat16 a) { return exp10f(a); }
static inline __host__ __device__ at::BFloat16 log (at::BFloat16 a) { return logf(a); }
static inline __host__ __device__ at::BFloat16 log10(at::BFloat16 a) { return log10f(a); }
static inline __host__ __device__ at::BFloat16 log1p(at::BFloat16 a) { return log1pf(a); }
static inline __host__ __device__ at::BFloat16 log2 (at::BFloat16 a) { return log2f(a); }
static inline __host__ __device__ at::BFloat16 expm1(at::BFloat16 a) { return expm1f(a); }
static inline __host__ __device__ at::BFloat16 cos (at::BFloat16 a) { return cosf(a); }
static inline __host__ __device__ at::BFloat16 sin (at::BFloat16 a) { return sinf(a); }
static inline __host__ __device__ at::BFloat16 sqrt (at::BFloat16 a) { return sqrtf(a); }
static inline __host__ __device__ at::BFloat16 rsqrt(at::BFloat16 a) { return rsqrtf(a); }
static inline __host__ __device__ at::BFloat16 floor(at::BFloat16 a) { return floorf(a); }
static inline __host__ __device__ at::BFloat16 trunc(at::BFloat16 a) { return truncf(a); }
static inline __host__ __device__ at::BFloat16 acos (at::BFloat16 a) { return acosf(a); }
static inline __host__ __device__ at::BFloat16 cosh (at::BFloat16 a) { return coshf(a); }
static inline __host__ __device__ at::BFloat16 acosh(at::BFloat16 a) { return acoshf(a); }
static inline __host__ __device__ at::BFloat16 asin (at::BFloat16 a) { return asinf(a); }
static inline __host__ __device__ at::BFloat16 sinh (at::BFloat16 a) { return sinhf(a); }
static inline __host__ __device__ at::BFloat16 asinh(at::BFloat16 a) { return asinhf(a); }
static inline __host__ __device__ at::BFloat16 tan (at::BFloat16 a) { return tanf(a); }
static inline __host__ __device__ at::BFloat16 atan (at::BFloat16 a) { return atanf(a); }
static inline __host__ __device__ at::BFloat16 tanh (at::BFloat16 a) { return tanhf(a); }
static inline __host__ __device__ at::BFloat16 erf (at::BFloat16 a) { return erff(a); }
static inline __host__ __device__ at::BFloat16 erfc (at::BFloat16 a) { return erfcf(a); }
static inline __host__ __device__ at::BFloat16 abs (at::BFloat16 a) { return fabsf(a); }
static inline __host__ __device__ at::BFloat16 round(at::BFloat16 a) { return nearbyintf(a); }
static inline __host__ __device__ at::BFloat16 frac (at::BFloat16 a) { return a - truncf(a); }
static inline __host__ __device__ at::BFloat16 cinv (at::BFloat16 a) { return 1.0f / a; }
static inline __host__ __device__ at::BFloat16 add (at::BFloat16 a, at::BFloat16 b) { return a + b; }
static inline __host__ __device__ at::BFloat16 div (at::BFloat16 a, at::BFloat16 b) { return a / b; }
static inline __host__ __device__ at::BFloat16 mul (at::BFloat16 a, at::BFloat16 b) { return a * b; }
static inline __host__ __device__ at::BFloat16 sub (at::BFloat16 a, at::BFloat16 b) { return a - b; }
static inline __host__ __device__ at::BFloat16 pow (at::BFloat16 a, at::BFloat16 b) { return powf(a, b); }
static inline __host__ __device__ at::BFloat16 atan2(at::BFloat16 a, at::BFloat16 b) { return atan2f(a, b); }

static inline __host__ __device__ bool isnan(at::BFloat16 a) {
#ifdef _MSC_VER
// Windows requires this explicit conversion. The reason is unclear
// related issue with clang: https://reviews.llvm.org/D37906
return ::isnan((float) a);
Review comment (Contributor): nit (you don't have to do anything about this, just FYI for later): people generally prefer static_cast<float> in C++, since it prevents the C-style cast from, e.g., being treated like a reinterpret_cast.
#else
return ::isnan(a);
#endif
}

static inline __host__ __device__ bool isinf(at::BFloat16 a) {
#ifdef _MSC_VER
// Windows requires this explicit conversion. The reason is unclear
// related issue with clang: https://reviews.llvm.org/D37906
return ::isinf((float) a);
#else
return ::isinf(a);
#endif
}
};

Comment on lines +326 to +346

Contributor Author: @ezyang, this part.

Contributor Author:
> This diff looks very familiar to me. Is it a respin of a diff you did previously?

The diff is identical to the one that was reverted, except for this part, which fixes the Windows build.
// DEPRECATED: use math functions from std and cuda math API (if needed)
// note that the functions exp10,erfinv and cinv
// are not in the std namespace
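Note: the math entries in the new THCNumerics<at::BFloat16> specialization all follow one pattern: the bfloat16 argument implicitly widens to float for the single-precision call (expf, logf, ...), and the float result converts back to bfloat16 through the converting constructor on return. A minimal stand-in type (assumed for illustration; not at::BFloat16, and truncating on conversion rather than rounding) showing that shape:

```cpp
// Stand-in bf16 type showing the widen-compute-narrow pattern used above.
#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

struct bf16 {
  uint16_t bits;
  bf16() : bits(0) {}
  bf16(float f) {                         // float -> bf16 (truncating here, for brevity)
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    bits = static_cast<uint16_t>(u >> 16);
  }
  operator float() const {                // bf16 -> float: widen into the high half
    uint32_t u = static_cast<uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

// Same shape as THCNumerics<at::BFloat16>::exp: the argument widens to float for
// expf, and the float result narrows back to bf16 on return.
static bf16 bf16_exp(bf16 a) { return expf(a); }

int main() {
  bf16 y = bf16_exp(bf16(1.0f));
  std::printf("exp(1) at bfloat16 precision ~ %g\n", static_cast<float>(y));
}
```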
3 changes: 3 additions & 0 deletions aten/src/THC/THCTensorMath.cu
@@ -123,3 +123,6 @@ struct NonZeroOp<bool>

#include <THC/generic/THCTensorMath.cu>
#include <THC/THCGenerateBoolType.h>

#include <THC/generic/THCTensorMath.cu>
#include <THC/THCGenerateBFloat16Type.h>
9 changes: 9 additions & 0 deletions aten/src/THC/THCTensorMath.h
@@ -10,6 +10,9 @@
#include <THC/generic/THCTensorMath.h>
#include <THC/THCGenerateBoolType.h>

#include <THC/generic/THCTensorMath.h>
#include <THC/THCGenerateBFloat16Type.h>

#include <THC/generic/THCTensorMathBlas.h>
#include <THC/THCGenerateAllTypes.h>

@@ -34,6 +37,9 @@
#include <THC/generic/THCTensorMathReduce.h>
#include <THC/THCGenerateBoolType.h>

#include <THC/generic/THCTensorMathReduce.h>
#include <THC/THCGenerateBFloat16Type.h>

#include <THC/generic/THCTensorMathScan.h>
#include <THC/THCGenerateAllTypes.h>

@@ -58,6 +64,9 @@
#include <THC/generic/THCTensorIndex.h>
#include <THC/THCGenerateBoolType.h>

#include <THC/generic/THCTensorMasked.h>
#include <THC/THCGenerateBFloat16Type.h>

#include <THC/generic/THCTensorSort.h>
#include <THC/THCGenerateAllTypes.h>

5 changes: 5 additions & 0 deletions aten/src/THC/generated/THCTensorMaskedBFloat16.cu
@@ -0,0 +1,5 @@
#include <THC/THCTensorMasked.cuh>
#include <THC/THCTensor.hpp>

#include <THC/generic/THCTensorMasked.cu>
#include <THC/THCGenerateBFloat16Type.h>
5 changes: 5 additions & 0 deletions aten/src/THC/generated/THCTensorMathReduceBFloat16.cu
@@ -0,0 +1,5 @@
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensor.hpp>

#include <THC/generic/THCTensorMathReduce.cu>
#include <THC/THCGenerateBFloat16Type.h>
3 changes: 3 additions & 0 deletions test/common_utils.py
@@ -724,6 +724,9 @@ def assertTensorsEqual(a, b):
if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)):
# CPU half and bfloat16 tensors don't have the methods we need below
a = a.to(torch.float32)
if (a.device.type == 'cuda' and a.dtype == torch.bfloat16):
# CUDA bfloat16 tensors don't have the methods we need below
a = a.to(torch.float32)
b = b.to(a)

if (a.dtype == torch.bool) != (b.dtype == torch.bool):