Commit b79bac0

CaoE authored and pytorchmergebot committed
Make the data types of output and input consistent for batchnorm (#84410)

The TTS model crashes due to the following issue: when the input of BN is not contiguous and its data type differs from that of the parameters, BN raises `RuntimeError: !needs_dynamic_casting<func_t>::check(iter) INTERNAL ASSERT FAILED at "xxx/pytorch/aten/src/ATen/native/cpu/Loops.h":311, please report a bug to PyTorch`. Making the data types of output and input consistent for batchnorm fixes the issue. Pull Request resolved: #84410 Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
1 parent c2f29e7 commit b79bac0
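For context, a minimal sketch of the failure mode, mirroring the regression test added below: an eval-mode BatchNorm2d with float32 parameters applied to a non-contiguous bfloat16 input on CPU.

import torch

# float32 parameters, eval mode (normalizes with running stats)
bn = torch.nn.BatchNorm2d(2).float().eval()

# bfloat16 input made non-contiguous via permute; shape becomes (1, 2, 3, 4)
x = torch.arange(24, dtype=torch.bfloat16).reshape(1, 3, 2, 4)
x = x.permute(0, 2, 1, 3)

# Before this fix: RuntimeError (the INTERNAL ASSERT in Loops.h quoted above).
# After this fix: runs, and the output keeps the input's dtype.
out = bn(x)
print(out.dtype)  # torch.bfloat16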

File tree: 2 files changed, +10 -2 lines


aten/src/ATen/native/Normalization.cpp

Lines changed: 1 addition & 2 deletions
@@ -142,8 +142,7 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
       .check_all_same_dtype(false)
       .promote_inputs_to_common_dtype(false)
       .build();
-
-  cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) {
+  cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) -> scalar_t {
     return ((input - mean) * invstd) * weight + bias;
   });
   return std::make_tuple(output, save_mean, save_invstd);
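
Why the one-line change works: without an explicit trailing return type, the lambda's return type is deduced from the mixed scalar_t/param_t arithmetic, i.e. param_t (float here) rather than the input/output dtype scalar_t (bfloat16 here); cpu_kernel would then need a dynamic cast to write the output, which is exactly the assert quoted in the commit message. The explicit -> scalar_t casts the result back to the input dtype. Below is a rough Python analogue of the per-element computation, assuming bfloat16 input and float32 statistics/affine parameters (the function and argument names are illustrative, not from the patch):

import torch

def bn_transform(input_bf, mean_f, invstd_f, weight_f, bias_f):
    # Arithmetic runs in the parameter dtype (float32), mirroring the
    # promotion of `input` (scalar_t) to param_t in the C++ kernel.
    y = (input_bf.float() - mean_f) * invstd_f * weight_f + bias_f
    # The explicit `-> scalar_t` trailing return type corresponds to
    # casting the result back to the input/output dtype.
    return y.to(input_bf.dtype)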

test/test_nn.py

Lines changed: 9 additions & 0 deletions
@@ -8883,6 +8883,15 @@ def test_batchnorm_non_contig_cpu(self):
         self.assertTrue(ref_out.is_contiguous())
         self.assertEqual(out, ref_out)
 
+        input_bf = torch.arange(24, dtype=torch.bfloat16).reshape(1, 3, 2, 4)
+        input_bf = input_bf.permute(0, 2, 1, 3)
+        input_f = input_bf.float()
+        bn_mix = torch.nn.BatchNorm2d(2).float().eval()
+        ref_bn_f = deepcopy(bn_mix)
+        out_bf = bn_mix(input_bf)
+        ref_out_bf = ref_bn_f(input_f)
+        self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
     def test_batchnorm_cudnn_nhwc(self):
