3 changes: 1 addition & 2 deletions aten/src/ATen/native/Normalization.cpp
@@ -142,8 +142,7 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
       .check_all_same_dtype(false)
       .promote_inputs_to_common_dtype(false)
       .build();
-
-  cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) {
+  cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) -> scalar_t {
     return ((input - mean) * invstd) * weight + bias;
   });
   return std::make_tuple(output, save_mean, save_invstd);
9 changes: 9 additions & 0 deletions test/test_nn.py
@@ -8843,6 +8843,15 @@ def test_batchnorm_non_contig_cpu(self):
         self.assertTrue(ref_out.is_contiguous())
         self.assertEqual(out, ref_out)
+
+        input_bf = torch.arange(24, dtype=torch.bfloat16).reshape(1, 3, 2, 4)
+        input_bf = input_bf.permute(0, 2, 1, 3)
+        input_f = input_bf.float()
+        bn_mix = torch.nn.BatchNorm2d(2).float().eval()
+        ref_bn_f = deepcopy(bn_mix)
+        out_bf = bn_mix(input_bf)
+        ref_out_bf = ref_bn_f(input_f)
+        self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
     def test_batchnorm_cudnn_nhwc(self):