16 changes: 6 additions & 10 deletions aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
@@ -10,22 +10,18 @@ namespace at { namespace native {
 
 void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "addcmul_cuda", [&] {
-      auto alpha = value.to<scalar_t>();
-      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
-        return a + alpha * b * c;
-      });
+    auto alpha = value.to<scalar_t>();
+    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+      return a + alpha * b * c;
     });
   });
 }
 
 void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "addcdiv_cuda", [&] {
-      auto alpha = value.to<scalar_t>();
-      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
-        return a + alpha * (b / c);
-      });
+    auto alpha = value.to<scalar_t>();
+    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+      return a + alpha * (b / c);
     });
   });
 }
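Editor's note: with the AT_SKIP_BFLOAT16_IF_NOT_ROCM guard removed, addcmul/addcdiv should dispatch for bfloat16 on any CUDA build, not just ROCm. A minimal usage sketch of the two operations these kernels implement (illustrative only, not part of the diff; assumes a bfloat16-capable CUDA device is available):

import torch

# Hedged example: exercises the addcmul/addcdiv CUDA kernels above in bfloat16.
if torch.cuda.is_available():
    a = torch.randn(8, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(8, device="cuda", dtype=torch.bfloat16)
    c = torch.rand(8, device="cuda", dtype=torch.bfloat16) + 0.5  # keep divisor away from zero
    out_mul = torch.addcmul(a, b, c, value=2.0)  # a + 2.0 * b * c
    out_div = torch.addcdiv(a, b, c, value=2.0)  # a + 2.0 * (b / c)
    print(out_mul.dtype, out_div.dtype)  # expected: torch.bfloat16 torch.bfloat16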
6 changes: 1 addition & 5 deletions aten/src/ATen/native/cuda/ReduceMomentKernel.cu
@@ -30,9 +30,7 @@ void std_var_kernel_impl<at::BFloat16>(TensorIterator& iter, bool unbiased, bool
 
 static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "std_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "std_cuda", [&] {
-      std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
-    });
+    std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
   });
 }

@@ -49,14 +47,12 @@ static void mean_kernel_cuda(TensorIterator& iter) {
     // type promotion that does cast and reduction in a single kernel
     return mean_kernel_impl<at::Half, float, float>(iter);
   }
-#ifdef __HIP_PLATFORM_HCC__
   else if(iter.dtype() == kBFloat16) {
     return mean_kernel_impl<at::BFloat16, float>(iter);
   } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return mean_kernel_impl<at::BFloat16, float, float>(iter);
   }
-#endif
   AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cuda", [&]() {
     mean_kernel_impl<scalar_t>(iter);
   });
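Editor's note: dropping the __HIP_PLATFORM_HCC__ guard means the bfloat16 branches of mean, and the std/var dispatch above, are taken on CUDA as well as ROCm. A rough sketch of the call paths these branches serve (illustrative only; assumes a bfloat16-capable CUDA device):

import torch

# Hedged example: bfloat16 reductions that should now reach the CUDA kernels above.
if torch.cuda.is_available():
    x = torch.randn(3, 4, device="cuda", dtype=torch.bfloat16)
    m = x.mean()                             # likely the mean_kernel_impl<at::BFloat16, float> branch
    v = x.var(dim=1, unbiased=True)          # std_var_kernel_cuda dispatch now includes bfloat16
    promoted = x.mean(dtype=torch.float32)   # cast-and-reduce branch (bfloat16 input, float output)
    print(m.dtype, v.dtype, promoted.dtype)  # expected: torch.bfloat16 torch.bfloat16 torch.float32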
23 changes: 12 additions & 11 deletions test/test_torch.py
@@ -19787,20 +19787,20 @@ def inner(self, device, dtype):
     ('mul', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2),
     ('mul', 'scalar', _small_0d, lambda t, d: [_small_0d(torch.int32, d)], 1e-2),
     ('div', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('div', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('true_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
         1e-5, 1e-5, _types, _cpu_types, False),
     ('true_divide', 'with_inplace', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('true_divide', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
         1e-5, 1e-5, _types, _cpu_types, False),
     ('true_divide', 'tensor_with_inplace', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('floor_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1, 1e-5, 1e-5, _types),
     ('floor_divide', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1, 1e-5, 1e-5, _types),
@@ -19834,15 +19834,16 @@ def inner(self, device, dtype):
     ('addcdiv', '', _small_2d,
         lambda t, d: [_small_2d(t, d),
                       _small_2d(t, d, has_zeros=False)], 1, 1, 1e-3,
-        _float_types2, _cpu_types, True),
+        torch.testing.get_all_fp_dtypes(), _cpu_types, True),
     ('addcdiv', 'scalar', _small_2d,
         lambda t, d: [_number(2.8, 1, t), _small_2d(t, d),
                       _small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3,
         _float_types, _cpu_types, True),
-    ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3, _types2),
+    ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3,
+        torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
     ('addcmul', 'scalar', _small_3d,
         lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2,
-        1e-1, 1e-5, _types2, _cpu_types, True,
+        1e-1, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, True,
         [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
     ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)], 0, True),
@@ -19957,9 +19958,9 @@ def inner(self, device, dtype):
         1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('minimum', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],
         1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
-    ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
-    ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
-    ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, _float_types2, _cpu_types, False),
+    ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
+    ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
+    ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
     # Double here because the CPU result will be wrong otherwise
     ('mean', '64bit_indexing', _giant_1d, lambda t, d: [],
         1e-3, 1e-5, 1e-5, [torch.double], _cpu_types, False, [slowTest]),
@@ -19983,7 +19984,7 @@ def inner(self, device, dtype):
     ('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
     ('var', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
     ('var', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
-    ('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
+    ('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
     ('ndimension', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('nelement', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('numel', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
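Editor's note: the test changes replace the hard-coded _float_types2/_types2 lists with the torch.testing helpers used in the new entries, so bfloat16 coverage follows whatever those helpers report. A quick way to inspect which dtypes the updated entries now run over (output varies by PyTorch build; this is a sketch, not part of the diff):

import torch

# Hedged sketch: print the dtype lists referenced by the updated test table.
fp_dtypes = torch.testing.get_all_fp_dtypes()
all_dtypes = torch.testing.get_all_dtypes(include_complex=False, include_bool=False)
print(fp_dtypes)                       # floating-point dtypes, including half/bfloat16
print(all_dtypes)                      # all dtypes minus complex and bool
print(torch.bfloat16 in fp_dtypes)     # expected True on builds with bfloat16 support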