Skip to content

Commit fb53c18

Browse files
bddppqfacebook-github-bot
authored andcommitted
Add aten mkldnn batch_norm operator
Summary: Pull Request resolved: #19206 Reviewed By: dzhulgakov Differential Revision: D14887205 fbshipit-source-id: ea00c9e3205c449d08ab29535309164f951aab95
1 parent 4864000 commit fb53c18

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/Config.h>
3+
#include <ATen/NativeFunctions.h>
4+
#include <tuple>
5+
6+
#if !AT_MKLDNN_ENABLED()
7+
8+
namespace at {
9+
namespace native {
10+
11+
std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
12+
const Tensor& self,
13+
const Tensor& weight,
14+
const Tensor& bias,
15+
const Tensor& running_mean,
16+
const Tensor& running_var,
17+
bool train,
18+
double momentum,
19+
double eps) {
20+
AT_ERROR("mkldnn_batch_norm: ATen not compiled with MKLDNN support");
21+
}
22+
23+
} // namespace native
24+
} // namespace at
25+
26+
#else // AT_MKLDNN_EBABLED
27+
28+
#include <ATen/native/mkldnn/MKLDNNCommon.h>
29+
30+
namespace at {
31+
namespace native {
32+
33+
std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
34+
const Tensor& input,
35+
const Tensor& weight,
36+
const Tensor& bias,
37+
const Tensor& running_mean,
38+
const Tensor& running_var,
39+
bool train,
40+
double momentum,
41+
double eps) {
42+
ideep::tensor& x = itensor_from_mkldnn(input);
43+
ideep::tensor& w = itensor_from_mkldnn(weight);
44+
ideep::tensor& b = itensor_from_mkldnn(bias);
45+
ideep::tensor& m = itensor_from_mkldnn(running_mean);
46+
ideep::tensor& v = itensor_from_mkldnn(running_var);
47+
48+
ideep::tensor y;
49+
50+
if (train) {
51+
// TODO: support training
52+
AT_ERROR("mkldnn_batch_norm: mkldnn training is not supported in yet.");
53+
54+
// ideep::tensor saved_mean;
55+
// ideep::tensor saved_var;
56+
// ideep::batch_normalization_forward_training::compute<AllocForMKLDNN>(
57+
// x, w, b, y, saved_mean, saved_var, m, v, momentum, eps);
58+
// return std::make_tuple(
59+
// new_with_itensor_mkldnn(std::move(y), input.options()),
60+
// new_with_itensor_mkldnn(std::move(saved_mean), input.options()),
61+
// new_with_itensor_mkldnn(std::move(saved_var), input.options()));
62+
} else {
63+
AT_ASSERTM(input.dim() == 4 || input.dim() == 5,
64+
"mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm");
65+
ideep::batch_normalization_forward_inference::compute<AllocForMKLDNN>(
66+
x, m, v, w, b, y, eps);
67+
return std::make_tuple(
68+
new_with_itensor_mkldnn(std::move(y), input.options()),
69+
new_with_itensor_mkldnn(ideep::tensor{}, input.options()),
70+
new_with_itensor_mkldnn(ideep::tensor{}, input.options()));
71+
}
72+
}
73+
74+
} // namespace native
75+
} // namespace at
76+
77+
#endif // AT_MKLDNN_EBABLED

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,7 @@
13071307
dispatch:
13081308
CPU: batch_norm_cpu
13091309
CUDA: batch_norm_cuda
1310+
MkldnnCPU: mkldnn_batch_norm
13101311

13111312
- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
13121313
dispatch:

test/test_mkldnn.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,19 @@ def test_avg_pool2d(self):
148148
avg_pool2d(x),
149149
avg_pool2d(x.to_mkldnn()).to_dense())
150150

151+
def test_batch_norm2d(self):
152+
N = torch.randint(3, 10, (1,)).item()
153+
C = torch.randint(3, 100, (1,)).item()
154+
x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
155+
156+
# TODO: support training
157+
for train in [False]:
158+
bn = torch.nn.BatchNorm2d(C).float().train(train)
159+
mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
160+
self.assertEqual(
161+
bn(x),
162+
mkldnn_bn(x.to_mkldnn()).to_dense())
163+
151164

152165
if __name__ == '__main__':
153166
run_tests()

0 commit comments

Comments
 (0)