Skip to content

Commit 028b719

Browse files
zasdfgbnm authored and xuzhao9 committed
CUDA BFloat16 unary ops part 2 (#44824)
Summary: Pull Request resolved: #44824
Reviewed By: mruberry
Differential Revision: D23752360
Pulled By: ngimel
fbshipit-source-id: 3aadaf9db9d4e4937aa38671e8589ecbeece709d
1 parent 5b91623 commit 028b719

File tree

3 files changed

+12
-18
lines changed

3 files changed

+12
-18
lines changed

aten/src/ATen/native/cuda/UnaryFractionKernels.cu

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,10 +89,8 @@ __host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::comple
8989

9090
void reciprocal_kernel_cuda(TensorIterator& iter) {
9191
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "reciprocal_cuda", [&]() {
92-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "reciprocal_cuda", [&] {
93-
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
94-
return reciprocal_wrapper(a);
95-
});
92+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
93+
return reciprocal_wrapper(a);
9694
});
9795
});
9896
}

aten/src/ATen/native/cuda/UnaryOpsKernel.cu

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,8 @@ void bitwise_not_kernel_cuda(TensorIterator& iter) {
3232

3333
void exp_kernel_cuda(TensorIterator& iter) {
3434
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() {
35-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "exp_cuda", [&] {
36-
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
37-
return ::exp(a);
38-
});
35+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
36+
return ::exp(a);
3937
});
4038
});
4139
}
@@ -140,10 +138,8 @@ void logit_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) {
140138

141139
void erf_kernel_cuda(TensorIterator& iter) {
142140
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "erf_cuda", [&]() {
143-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "erf_cuda", [&] {
144-
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
145-
return ::erf(a);
146-
});
141+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
142+
return ::erf(a);
147143
});
148144
});
149145
}

test/test_torch.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -20114,18 +20114,18 @@ def inner(self, device, dtype):
2011420114
('acosh', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2),
2011520115
('asinh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2),
2011620116
('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2),
20117-
('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2, [torch.bfloat16]),
20117+
('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
2011820118
('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
2011920119
('erfinv', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
20120-
('exp', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types),
20120+
('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
2012120121
('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
20122-
lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types2, [torch.bfloat16]),
20122+
lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
2012320123
('expm1', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types),
2012420124
('expm1', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
2012520125
lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
20126-
('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, _float_types2, [torch.bfloat16]),
20127-
('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, _float_types2, [torch.bfloat16]),
20128-
('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, _float_types2, [torch.bfloat16]),
20126+
('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
20127+
('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
20128+
('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
2012920129
('floor', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
2013020130
('frac', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
2013120131
('round', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),

0 commit comments

Comments
 (0)