Skip to content

Commit 1fbabff

Browse files
syed-ahmedfacebook-github-bot
authored andcommitted
Refactor THCNumerics and add common math functions for at::Half (#10301)
Summary: **Summary**: This PR is a followup of mruberry's #9318. It tries to achieve the following: - Specializing std common math functions for `at::Half` type. - Create `CUDANumerics.cuh` to contain necessary parts from `THCNumerics.cuh`. - Update `THCNumerics.cuh` with new usage and comments to demonstrate the best practice for developers and hence, making way for its deprecation. - Remove legacy/redundant code path. - Remove unused CUDA HALF macros (see separate PR #10147) **Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. This header is derived from the legacy `THCNumerics.cuh`. Following are some rationale behind why some functions were kept while others were removed: - All arithmetic can now be done in ATen using binary cuda kernel or CUDA tensor pointwise apply (check #8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on implicit conversion to float. - Functions that are c/c++ standard compliant, have been specialized for user defined types, for instance, the std namespace has been opened up for `at::Half`, that defines math function definitions for `at::Half`. Check `Half-inl.h` - Some standard compliant functions are specialized here for performance reasons. For instance, `powi` is used for `pow` calculation on integral types. Moreover, `abs`, `isinf`, `isnan` are specialized to save one API call vs when used with std. Although this is subject to change, depending on if we really care about saving one API call. - Numeric limits such as `max/min` is removed since they call standard defines. Moreover, numeric limits for `at::Half` is present in `Half-inl.h`. I understood that HIP has some issue with `std::numeric_limits` and this the related github issue I found: ROCm/hip#374. AlexVlx mentions that the issue can be avoided by launching `std::numeric_limits` in `__device__`. Since, we are launching lambdas with device contexts, I don't see an issue why `std::numeric_limits` won't compile on HIP if launched with device context within a kernel, unless I am not aware of the real reason why max/min was there in THCNumerics in the first place. (Haven't ever tried a build with HIP). Here are some reference PRs that was handy in refactoring TH into ATen: - #6786 - #5475 - #9401 - #8689 - #8919 Pull Request resolved: #10301 Differential Revision: D9204758 Pulled By: soumith fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
1 parent 87a7840 commit 1fbabff

File tree

13 files changed

+340
-383
lines changed

13 files changed

+340
-383
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#pragma once
2+
3+
#include <cuda.h>
4+
#include <limits.h>
5+
6+
// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
7+
// types. This header is very specific to ROCm HIP and may be removed in the future.
8+
// This header is derived from the legacy THCNumerics.cuh.
9+
10+
namespace at{
11+
12+
template <typename T>
13+
struct numeric_limits {
14+
};
15+
16+
// WARNING: the following at::numeric_limits definitions are there only to support
17+
// HIP compilation for the moment. Use std::numeric_limits if you are not
18+
// compiling for ROCm.
19+
// from @colesbury: "The functions on numeric_limits aren't marked with
20+
// __device__ which is why they don't work with ROCm. CUDA allows them
21+
// because they're constexpr."
22+
template <>
23+
struct numeric_limits<uint8_t> {
24+
static inline __host__ __device__ uint8_t lowest() { return 0; }
25+
static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
26+
};
27+
28+
template <>
29+
struct numeric_limits<int8_t> {
30+
static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
31+
static inline __host__ __device__ int8_t max() { return INT8_MAX; }
32+
};
33+
34+
template <>
35+
struct numeric_limits<int16_t> {
36+
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
37+
static inline __host__ __device__ int16_t max() { return INT16_MAX; }
38+
};
39+
40+
template <>
41+
struct numeric_limits<int32_t> {
42+
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
43+
static inline __host__ __device__ int32_t max() { return INT32_MAX; }
44+
};
45+
46+
template <>
47+
struct numeric_limits<int64_t> {
48+
#ifdef _MSC_VER
49+
static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
50+
static inline __host__ __device__ int64_t max() { return _I64_MAX; }
51+
#else
52+
static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
53+
static inline __host__ __device__ int64_t max() { return INT64_MAX; }
54+
#endif
55+
};
56+
57+
template <>
58+
struct numeric_limits<at::Half> {
59+
static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits); }
60+
static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits); }
61+
};
62+
63+
template <>
64+
struct numeric_limits<float> {
65+
static inline __host__ __device__ float lowest() { return -FLT_MAX; }
66+
static inline __host__ __device__ float max() { return FLT_MAX; }
67+
};
68+
69+
template <>
70+
struct numeric_limits<double> {
71+
static inline __host__ __device__ double lowest() { return -DBL_MAX; }
72+
static inline __host__ __device__ double max() { return DBL_MAX; }
73+
};
74+
75+
} // namespace at

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include <THC/THCTensorMathReduce.cuh>
77
#include <THC/THCTensorSort.cuh>
88
#include <THC/THCThrustAllocator.cuh>
9-
#include <THC/THCNumerics.cuh>
109

1110
#include "ATen/AccumulateType.h"
11+
#include "ATen/cuda/NumericLimits.cuh"
1212

1313

1414
namespace at {
@@ -200,7 +200,7 @@ __global__ void cunn_SpatialSoftMaxForward(
200200
////////////////////////////////////////////////////////////
201201

202202
if (blockDim.x > 1) {
203-
accscalar_t max_input = THCNumerics<accscalar_t>::min();
203+
accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
204204
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
205205
const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
206206
max_input = Max<accscalar_t>()(max_input, value);
@@ -217,7 +217,7 @@ __global__ void cunn_SpatialSoftMaxForward(
217217
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
218218
output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
219219
} else {
220-
accscalar_t max_input = THCNumerics<accscalar_t>::min();
220+
accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
221221
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
222222
const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
223223
max_input = Max<accscalar_t>()(max_input, value);
@@ -403,9 +403,9 @@ cunn_SoftMaxForward(scalar_t *output, scalar_t *input, int classes)
403403

404404
// find the max
405405
accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
406-
input, classes, MaxFloat<scalar_t, accscalar_t>(), -THCNumerics<accscalar_t>::max());
406+
input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
407407
accscalar_t max_k = blockReduce<Max, accscalar_t>(
408-
sdata, threadMax, Max<accscalar_t>(), -THCNumerics<accscalar_t>::max());
408+
sdata, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
409409

410410
// reduce all values
411411
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(

aten/src/ATen/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ list(APPEND ATen_CUDA_TEST_SRCS
2525
${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu
2626
${CMAKE_CURRENT_SOURCE_DIR}/cuda_rng_test.cpp
2727
${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp
28-
${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp)
28+
${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu)
2930
if (CUDNN_FOUND)
3031
list(APPEND ATen_CUDA_TEST_SRCS
3132
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_test.cpp)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#define CATCH_CONFIG_MAIN
2+
#include "catch.hpp"
3+
4+
#include "ATen/ATen.h"
5+
#include "ATen/cuda/NumericLimits.cuh"
6+
#include "cuda.h"
7+
#include "cuda_fp16.h"
8+
#include "cuda_runtime.h"
9+
10+
#include <assert.h>
11+
12+
using namespace at;
13+
14+
__device__ void test(){
15+
16+
// test half construction and implicit conversions in device
17+
assert(Half(3) == Half(3.0f));
18+
assert(static_cast<Half>(3.0f) == Half(3.0f));
19+
// there is no float <=> __half implicit conversion
20+
assert(static_cast<Half>(3.0f) == 3.0f);
21+
22+
__half a = __float2half(3.0f);
23+
__half b = __float2half(2.0f);
24+
__half c = a - Half(b);
25+
assert(static_cast<Half>(c) == Half(1.0));
26+
27+
// asserting if the functions used on
28+
// half types give almost equivalent results when using
29+
// functions on double.
30+
// The purpose of these asserts are to test the device side
31+
// half API for the common mathematical functions.
32+
// Note: When calling std math functions from device, don't
33+
// use the std namespace, but just "::" so that the function
34+
// gets resolved from nvcc math_functions.hpp
35+
36+
float threshold = 0.00001;
37+
assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
38+
assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
39+
assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
40+
assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
41+
assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
42+
assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
43+
assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
44+
assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
45+
assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
46+
assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
47+
assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
48+
assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
49+
assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
50+
assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
51+
assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
52+
assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
53+
assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
54+
assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
55+
assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
56+
assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
57+
assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
58+
assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
59+
assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
60+
assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
61+
assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
62+
assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
63+
assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
64+
assert(::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);
65+
// note: can't use namespace on isnan and isinf in device code
66+
#ifdef _MSC_VER
67+
// Windows requires this explicit conversion. The reason is unclear
68+
// related issue with clang: https://reviews.llvm.org/D37906
69+
assert(::abs(::isnan((float)Half(0.0)) - ::isnan(0.0f)) <= threshold);
70+
assert(::abs(::isinf((float)Half(0.0)) - ::isinf(0.0f)) <= threshold);
71+
#else
72+
assert(::abs(::isnan(Half(0.0)) - ::isnan(0.0f)) <= threshold);
73+
assert(::abs(::isinf(Half(0.0)) - ::isinf(0.0f)) <= threshold);
74+
#endif
75+
}
76+
77+
__global__ void kernel(){
78+
test();
79+
}
80+
81+
void launch_function(){
82+
kernel<<<1,1>>>();
83+
}
84+
85+
TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
86+
launch_function();
87+
cudaError_t err = cudaDeviceSynchronize();
88+
REQUIRE(err == cudaSuccess);
89+
}
90+

aten/src/ATen/test/half_test.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
#include <iostream>
66
#include <limits>
77
#include <sstream>
8+
#include <cmath>
89
#include <type_traits>
10+
#include "test_seed.h"
11+
#include "test_assert.h"
912

1013
using namespace at;
1114

@@ -115,3 +118,43 @@ ASSERT_SAME_TYPE(max_exponent);
115118
ASSERT_SAME_TYPE(max_exponent10);
116119
ASSERT_SAME_TYPE(traps);
117120
ASSERT_SAME_TYPE(tinyness_before);
121+
122+
TEST_CASE( "half common math functions test", "[]" ) {
123+
float threshold = 0.00001;
124+
assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
125+
assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
126+
assert(std::abs(std::log(Half(1.0)) - std::log(1.0f)) <= threshold);
127+
assert(std::abs(std::log10(Half(1000.0)) - std::log10(1000.0f)) <= threshold);
128+
assert(std::abs(std::log1p(Half(0.0)) - std::log1p(0.0f)) <= threshold);
129+
assert(std::abs(std::log2(Half(1000.0)) - std::log2(1000.0f)) <= threshold);
130+
assert(std::abs(std::expm1(Half(1.0)) - std::expm1(1.0f)) <= threshold);
131+
assert(std::abs(std::cos(Half(0.0)) - std::cos(0.0f)) <= threshold);
132+
assert(std::abs(std::sin(Half(0.0)) - std::sin(0.0f)) <= threshold);
133+
assert(std::abs(std::sqrt(Half(100.0)) - std::sqrt(100.0f)) <= threshold);
134+
assert(std::abs(std::ceil(Half(2.4)) - std::ceil(2.4f)) <= threshold);
135+
assert(std::abs(std::floor(Half(2.7)) - std::floor(2.7f)) <= threshold);
136+
assert(std::abs(std::trunc(Half(2.7)) - std::trunc(2.7f)) <= threshold);
137+
assert(std::abs(std::acos(Half(-1.0)) - std::acos(-1.0f)) <= threshold);
138+
assert(std::abs(std::cosh(Half(1.0)) - std::cosh(1.0f)) <= threshold);
139+
assert(std::abs(std::acosh(Half(1.0)) - std::acosh(1.0f)) <= threshold);
140+
assert(std::abs(std::asin(Half(1.0)) - std::asin(1.0f)) <= threshold);
141+
assert(std::abs(std::sinh(Half(1.0)) - std::sinh(1.0f)) <= threshold);
142+
assert(std::abs(std::asinh(Half(1.0)) - std::asinh(1.0f)) <= threshold);
143+
assert(std::abs(std::tan(Half(0.0)) - std::tan(0.0f)) <= threshold);
144+
assert(std::abs(std::atan(Half(1.0)) - std::atan(1.0f)) <= threshold);
145+
assert(std::abs(std::tanh(Half(1.0)) - std::tanh(1.0f)) <= threshold);
146+
assert(std::abs(std::erf(Half(10.0)) - std::erf(10.0f)) <= threshold);
147+
assert(std::abs(std::erfc(Half(10.0)) - std::erfc(10.0f)) <= threshold);
148+
assert(std::abs(std::abs(Half(-3.0)) - std::abs(-3.0f)) <= threshold);
149+
assert(std::abs(std::round(Half(2.3)) - std::round(2.3f)) <= threshold);
150+
assert(std::abs(std::pow(Half(2.0), Half(10.0)) - std::pow(2.0f, 10.0f)) <= threshold);
151+
assert(std::abs(std::atan2(Half(7.0), Half(0.0)) - std::atan2(7.0f, 0.0f)) <= threshold);
152+
#ifdef __APPLE__
153+
// @TODO: can macos do implicit conversion of Half?
154+
assert(std::abs(std::isnan(static_cast<float>(Half(0.0))) - std::isnan(0.0f)) <= threshold);
155+
assert(std::abs(std::isinf(static_cast<float>(Half(0.0))) - std::isinf(0.0f)) <= threshold);
156+
#else
157+
assert(std::abs(std::isnan(Half(0.0)) - std::isnan(0.0f)) <= threshold);
158+
assert(std::abs(std::isinf(Half(0.0)) - std::isinf(0.0f)) <= threshold);
159+
#endif
160+
}

aten/src/THC/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
1818
endforeach()
1919
endforeach()
2020

21-
IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
22-
LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/THCHalf.cu)
23-
ENDIF()
24-
2521
set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
2622
${CMAKE_CURRENT_SOURCE_DIR}/THCCachingAllocator.cpp
2723
${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp

aten/src/THC/THCAtomics.cuh

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
#include "THC.h"
55
#include "THCHalf.h"
66
#include "THCNumerics.cuh"
7-
8-
namespace at { struct Half; }
7+
#include "ATen/ATen.h"
98

109
template <typename T, size_t n>
1110
struct AtomicAddIntegerImpl;
@@ -118,8 +117,8 @@ static inline __device__ void atomicAdd(half *address, half val) {
118117
old = atomicCAS(address_as_ui, assumed, old);
119118
} while (assumed != old);
120119
}
121-
static inline __device__ void atomicAdd(at::Half *address, half val) {
122-
return atomicAdd(reinterpret_cast<half*>(address), val);
120+
static inline __device__ void atomicAdd(at::Half *address, at::Half val) {
121+
atomicAdd(reinterpret_cast<half*>(address), val);
123122
}
124123

125124
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)

aten/src/THC/THCHalf.cu

Lines changed: 0 additions & 51 deletions
This file was deleted.

aten/src/THC/THCHalf.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,7 @@ typedef __half_raw half;
1212
#endif
1313
#endif
1414

15-
THC_EXTERNC void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len);
16-
THC_EXTERNC void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len);
1715
THC_API half THC_float2half(float a);
1816
THC_API float THC_half2float(half a);
1917

20-
/* Check for native fp16 support on the current device (CC 5.3+) */
21-
THC_API int THC_nativeHalfInstructions(THCState *state);
22-
23-
/* Check for performant native fp16 support on the current device */
24-
THC_API int THC_fastHalfInstructions(THCState *state);
25-
2618
#endif

0 commit comments

Comments
 (0)