Skip to content

Commit 7bd8a69

Browse files
zasdfgbnm authored and facebook-github-bot committed
CUDA BFloat div, addcdiv, addcmul, mean, var (#44758)
Summary: Pull Request resolved: #44758 Reviewed By: mruberry Differential Revision: D23752317 Pulled By: ngimel fbshipit-source-id: 77992cf991f4e2b4b6839de73ea7e6ce2e1061c6
1 parent f175830 commit 7bd8a69

File tree

3 files changed

+19
-26
lines changed

3 files changed

+19
-26
lines changed

aten/src/ATen/native/cuda/PointwiseOpsKernel.cu

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,18 @@ namespace at { namespace native {
1010

1111
void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) {
1212
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() {
13-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "addcmul_cuda", [&] {
14-
auto alpha = value.to<scalar_t>();
15-
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
16-
return a + alpha * b * c;
17-
});
13+
auto alpha = value.to<scalar_t>();
14+
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
15+
return a + alpha * b * c;
1816
});
1917
});
2018
}
2119

2220
void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
2321
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() {
24-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "addcdiv_cuda", [&] {
25-
auto alpha = value.to<scalar_t>();
26-
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
27-
return a + alpha * (b / c);
28-
});
22+
auto alpha = value.to<scalar_t>();
23+
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
24+
return a + alpha * (b / c);
2925
});
3026
});
3127
}

aten/src/ATen/native/cuda/ReduceMomentKernel.cu

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ void std_var_kernel_impl<at::BFloat16>(TensorIterator& iter, bool unbiased, bool
3030

3131
static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
3232
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "std_cuda", [&]() {
33-
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "std_cuda", [&] {
34-
std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
35-
});
33+
std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
3634
});
3735
}
3836

@@ -49,14 +47,12 @@ static void mean_kernel_cuda(TensorIterator& iter) {
4947
// type promotion that does cast and reduction in a single kernel
5048
return mean_kernel_impl<at::Half, float, float>(iter);
5149
}
52-
#ifdef __HIP_PLATFORM_HCC__
5350
else if(iter.dtype() == kBFloat16) {
5451
return mean_kernel_impl<at::BFloat16, float>(iter);
5552
} else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
5653
// type promotion that does cast and reduction in a single kernel
5754
return mean_kernel_impl<at::BFloat16, float, float>(iter);
5855
}
59-
#endif
6056
AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cuda", [&]() {
6157
mean_kernel_impl<scalar_t>(iter);
6258
});

test/test_torch.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19787,20 +19787,20 @@ def inner(self, device, dtype):
1978719787
('mul', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2),
1978819788
('mul', 'scalar', _small_0d, lambda t, d: [_small_0d(torch.int32, d)], 1e-2),
1978919789
('div', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
19790-
1e-1, 1e-5, _float_types2),
19790+
1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
1979119791
('div', 'tensor', _small_3d,
1979219792
lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
19793-
1e-1, 1e-5, _float_types2),
19793+
1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
1979419794
('true_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
1979519795
1e-5, 1e-5, _types, _cpu_types, False),
1979619796
('true_divide', 'with_inplace', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
19797-
1e-1, 1e-5, _float_types2),
19797+
1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
1979819798
('true_divide', 'tensor', _small_3d,
1979919799
lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
1980019800
1e-5, 1e-5, _types, _cpu_types, False),
1980119801
('true_divide', 'tensor_with_inplace', _small_3d,
1980219802
lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
19803-
1e-1, 1e-5, _float_types2),
19803+
1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
1980419804
('floor_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1, 1e-5, 1e-5, _types),
1980519805
('floor_divide', 'tensor', _small_3d,
1980619806
lambda t, d: [_small_3d(t, d, has_zeros=False)], 1, 1e-5, 1e-5, _types),
@@ -19834,15 +19834,16 @@ def inner(self, device, dtype):
1983419834
('addcdiv', '', _small_2d,
1983519835
lambda t, d: [_small_2d(t, d),
1983619836
_small_2d(t, d, has_zeros=False)], 1, 1, 1e-3,
19837-
_float_types2, _cpu_types, True),
19837+
torch.testing.get_all_fp_dtypes(), _cpu_types, True),
1983819838
('addcdiv', 'scalar', _small_2d,
1983919839
lambda t, d: [_number(2.8, 1, t), _small_2d(t, d),
1984019840
_small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3,
1984119841
_float_types, _cpu_types, True),
19842-
('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3, _types2),
19842+
('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3,
19843+
torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
1984319844
('addcmul', 'scalar', _small_3d,
1984419845
lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2,
19845-
1e-1, 1e-5, _types2, _cpu_types, True,
19846+
1e-1, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, True,
1984619847
[_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
1984719848
('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
1984819849
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)], 0, True),
@@ -19957,9 +19958,9 @@ def inner(self, device, dtype):
1995719958
1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
1995819959
('minimum', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],
1995919960
1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
19960-
('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
19961-
('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
19962-
('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, _float_types2, _cpu_types, False),
19961+
('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
19962+
('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
19963+
('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
1996319964
# Double here because the CPU result will be wrong otherwise
1996419965
('mean', '64bit_indexing', _giant_1d, lambda t, d: [],
1996519966
1e-3, 1e-5, 1e-5, [torch.double], _cpu_types, False, [slowTest]),
@@ -19983,7 +19984,7 @@ def inner(self, device, dtype):
1998319984
('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
1998419985
('var', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
1998519986
('var', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
19986-
('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
19987+
('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
1998719988
('ndimension', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
1998819989
('nelement', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
1998919990
('numel', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),

0 commit comments

Comments (0)