
Commit 4534b4c

add aten mkldnn batchnorm operator
1 parent 5e1f0b2 commit 4534b4c

6 files changed, +258 -0 lines changed

aten/src/ATen/core/aten_interned_strings.h

Lines changed: 2 additions & 0 deletions

@@ -465,6 +465,8 @@ _(aten, mkldnn_convolution) \
 _(aten, mkldnn_convolution_backward) \
 _(aten, mkldnn_convolution_backward_input) \
 _(aten, mkldnn_convolution_backward_weights) \
+_(aten, mkldnn_batch_norm) \
+_(aten, mkldnn_batch_norm_backward) \
 _(aten, mm) \
 _(aten, mode) \
 _(aten, mse_loss) \
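
These entries intern the two new operator names so that "aten::mkldnn_batch_norm" and "aten::mkldnn_batch_norm_backward" resolve to static symbols. A minimal sketch of the lookup (illustrative, not part of the commit):

#include <ATen/core/interned_strings.h>
#include <cstring>

int main() {
  // With the _(aten, mkldnn_batch_norm) entry above, this qualified string
  // resolves to a statically interned symbol.
  c10::Symbol sym = c10::Symbol::fromQualString("aten::mkldnn_batch_norm");
  return std::strcmp(sym.toQualString(), "aten::mkldnn_batch_norm") == 0 ? 0 : 1;
}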

aten/src/ATen/native/Normalization.cpp

Lines changed: 16 additions & 0 deletions

@@ -319,6 +319,20 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
         std::make_tuple(2));
   }

+  bool use_mkldnn = (input.type().backend() == at::Backend::CPU
+                     && input.type().scalarType() == at::kFloat
+                     && (input.ndimension() == 4 || input.ndimension() == 5)
+                     && (weight.defined() && bias.defined())
+                     && ((running_mean.defined() && running_var.defined())
+                         || (!running_mean.defined() && !running_var.defined()))
+                     );
+
+  if (use_mkldnn) {
+    return std::tuple_cat(
+        at::mkldnn_batch_norm(
+            input, weight, bias, running_mean, running_var, training, momentum, eps),
+        std::make_tuple(3));
+  }
   return std::tuple_cat(
       at::native_batch_norm(
           input, weight, bias, running_mean, running_var, training, momentum, eps),

@@ -337,6 +351,8 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
     return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
   } else if (impl_index == 2) {
     return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
+  } else if (impl_index == 3) {
+    return at::mkldnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
  }
  AT_ASSERTM(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
}
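
The new branch fires for a float CPU tensor of rank 4 or 5 with defined weight and bias, and with running stats either both defined or both undefined; _batch_norm_impl_index then returns impl_index 3, which the backward dispatcher above maps to mkldnn_batch_norm_backward. A minimal call that takes this path (a sketch assuming an MKL-DNN-enabled build; shapes are illustrative):

#include <ATen/ATen.h>

int main() {
  // A 4-D float CPU input with defined weight/bias and both running stats
  // defined satisfies every term of use_mkldnn.
  at::Tensor input = at::randn({8, 3, 32, 32});  // NCHW
  at::Tensor weight = at::ones({3});
  at::Tensor bias = at::zeros({3});
  at::Tensor running_mean = at::zeros({3});
  at::Tensor running_var = at::ones({3});

  at::Tensor out = at::batch_norm(
      input, weight, bias, running_mean, running_var,
      /*training=*/false, /*momentum=*/0.1, /*eps=*/1e-5,
      /*cudnn_enabled=*/false);
  return 0;
}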
Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Config.h>
+
+#if !AT_MKLDNN_ENABLED()
+
+namespace at { namespace native {
+
+std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(const Tensor& input_,
+    const Tensor& weight_, const Tensor& bias_, const Tensor& running_mean_,
+    const Tensor& running_var_, bool training, double exponential_average_factor, double epsilon) {
+  AT_ERROR("mkldnn_batch_norm: ATen not compiled with MKLDNN support");
+}
+
+std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
+    const Tensor& input_, const Tensor& grad_output, const Tensor& weight_,
+    const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean,
+    const Tensor& save_var, double epsilon) {
+  AT_ERROR("mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
+}
+
+}} // namespace at::native
+
+#else // AT_MKLDNN_ENABLED
+
+#include <ATen/mkldnn/Runtime.h>
+#include <ATen/native/mkldnn/MKLDNNCommon.h>
+
+namespace at { namespace native {
+
+namespace {
+
+constexpr int max_dim = 3;
+
+struct BatchNormParams {
+  int64_t dim;
+  int64_t input_size[2 + max_dim];
+  double epsilon;
+  bool training;
+  bool use_running_stat;
+};
+
+void setBatchNormParams(BatchNormParams* params, const Tensor& input,
+    double epsilon, bool training, bool use_running_stat) {
+
+  memset(params, 0, sizeof(BatchNormParams));
+
+  params->dim = input.dim();
+  for (int64_t i = 0; i < params->dim; ++i) {
+    params->input_size[i] = input.size(i);
+  }
+
+  params->epsilon = epsilon;
+  params->training = training;
+  params->use_running_stat = use_running_stat;
+}
+
+struct BatchNormArgs {
+  BatchNormParams params;
+  ideep::tensor::dims input_tz;
+
+  BatchNormArgs(const Tensor& input, const Tensor& running_mean,
+      const Tensor& running_var, bool training, double epsilon) {
+
+    bool use_running_stat = (running_mean.defined() && running_var.defined());
+    setBatchNormParams(&params, input, epsilon, training, use_running_stat);
+
+    for (int64_t i = 0; i < params.dim; ++i) {
+      input_tz.push_back(params.input_size[i]);
+    }
+  }
+};
+
+} // namespace
+
+std::tuple<Tensor, Tensor, Tensor> _ideep_batch_norm(const Tensor& input,
+    const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
+    const Tensor& running_var, bool training, double exponential_average_factor, double epsilon) {
+
+  bool use_running_stat = (running_mean.defined() && running_var.defined());
+  const ideep::tensor& src_ = itensor_from_mkldnn(input);
+  const ideep::tensor& weight_ = itensor_from_mkldnn(weight);
+  const ideep::tensor& bias_ = itensor_from_mkldnn(bias);
+  ideep::tensor dst_, save_mean_, save_var_;
+
+  if (training) {
+    if (use_running_stat) {
+      ideep::tensor running_mean_ = itensor_from_mkldnn(running_mean);
+      ideep::tensor running_var_ = itensor_from_mkldnn(running_var);
+      ideep::batch_normalization_forward_training::compute(
+          src_, weight_, bias_, dst_, save_mean_, save_var_,
+          running_mean_, running_var_, exponential_average_factor, epsilon);
+    } else {
+      ideep::batch_normalization_forward_training::compute(src_, weight_, bias_, dst_,
+          save_mean_, save_var_, exponential_average_factor, epsilon);
+    }
+  } else {
+    if (use_running_stat) {
+      ideep::tensor running_mean_ = itensor_from_mkldnn(running_mean);
+      ideep::tensor running_var_ = itensor_from_mkldnn(running_var);
+      ideep::batch_normalization_forward_inference::compute(
+          src_, running_mean_, running_var_, weight_, bias_, dst_, epsilon);
+    } else {
+      ideep::batch_normalization_forward_inference::compute(src_, weight_, bias_, dst_, epsilon);
+    }
+  }
+
+  auto output = new_with_itensor_mkldnn(std::move(dst_), input.options());
+  auto mean_output = new_with_itensor_mkldnn(std::move(save_mean_), input.options());
+  auto var_output = new_with_itensor_mkldnn(std::move(save_var_), input.options());
+  return std::tuple<Tensor, Tensor, Tensor>{output, mean_output, var_output};
+}
+
+std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(const Tensor& input_,
+    const Tensor& weight_, const Tensor& bias_, const Tensor& running_mean_,
+    const Tensor& running_var_, bool training, double exponential_average_factor, double epsilon) {
+
+  if (input_.type_id() == MkldnnCPUTensorId()) {
+    return _ideep_batch_norm(input_, weight_, bias_, running_mean_, running_var_, training, exponential_average_factor, epsilon);
+  }
+
+  auto input = input_.contiguous();
+  auto weight = weight_.contiguous();
+  auto bias = bias_.contiguous();
+  auto running_mean = running_mean_.defined() ? running_mean_.contiguous() : running_mean_;
+  auto running_var = running_var_.defined() ? running_var_.contiguous() : running_var_;
+
+  auto output = at::empty_like(input);
+
+  int32_t ic = input.size(1);
+  auto save_mean = at::empty({ic}, weight.options());
+  auto save_var = at::empty({ic}, weight.options());
+
+  BatchNormArgs args(input, running_mean, running_var, training, epsilon);
+
+  auto type_ = ideep::tensor::data_type::f32;
+  ideep::tensor::descriptor src_desc(args.input_tz, type_);
+  ideep::tensor::descriptor statistic_desc_({ic}, type_);
+
+  ideep::tensor src_, dst, scale_, shift_, mean, var, run_mean_, run_var_;
+  src_.init(src_desc, input.data_ptr());
+  dst.init(src_desc, output.data_ptr());
+
+  scale_.init(statistic_desc_, weight.data_ptr());
+  shift_.init(statistic_desc_, bias.data_ptr());
+
+  mean.init(statistic_desc_, save_mean.data_ptr());
+  var.init(statistic_desc_, save_var.data_ptr());
+
+  ideep::tensor dst_, save_mean_, save_var_;
+
+  if (training) {
+    if (args.params.use_running_stat) {
+      run_mean_.init(statistic_desc_, running_mean.data_ptr());
+      run_var_.init(statistic_desc_, running_var.data_ptr());
+      ideep::batch_normalization_forward_training::compute(
+          src_, scale_, shift_, dst_, save_mean_, save_var_,
+          run_mean_, run_var_, exponential_average_factor, epsilon);
+    } else {
+      ideep::batch_normalization_forward_training::compute(src_, scale_, shift_, dst_,
+          save_mean_, save_var_, exponential_average_factor, epsilon);
+    }
+  } else {
+    if (args.params.use_running_stat) {
+      run_mean_.init(statistic_desc_, running_mean.data_ptr());
+      run_var_.init(statistic_desc_, running_var.data_ptr());
+      ideep::batch_normalization_forward_inference::compute(
+          src_, run_mean_, run_var_, scale_, shift_, dst_, epsilon);
+    } else {
+      ideep::batch_normalization_forward_inference::compute(src_, scale_, shift_, dst_, epsilon);
+    }
+  }
+
+  ideep::reorder::compute(dst_, dst);
+  ideep::reorder::compute(save_mean_, mean);
+  ideep::reorder::compute(save_var_, var);
+
+  return std::tuple<Tensor, Tensor, Tensor>{output, save_mean, save_var};
+}
+
+std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& input_,
+    const Tensor& grad_output, const Tensor& weight_, const Tensor& running_mean_,
+    const Tensor& running_var_, const Tensor& save_mean, const Tensor& save_var, double epsilon) {
+
+  auto input = input_.contiguous();
+  auto weight = weight_.contiguous();
+
+  auto grad_input = at::empty_like(input);
+  auto grad_weight = at::empty_like(weight);
+  auto grad_bias = at::empty_like(weight);
+
+  int32_t ic = input.size(1);
+
+  auto type_ = ideep::tensor::data_type::f32;
+  BatchNormArgs args(input, running_mean_, running_var_, true, epsilon);
+
+  ideep::tensor::descriptor src_desc(args.input_tz, type_);
+  ideep::tensor::descriptor statistic_desc_({ic}, type_);
+
+  ideep::tensor src_, grady_, gradx, gradw, gradb, save_mean_, save_var_, scale_;
+
+  src_.init(src_desc, input.data_ptr());
+  grady_.init(src_desc, grad_output.data_ptr());
+  gradx.init(src_desc, grad_input.data_ptr());
+
+  gradw.init(statistic_desc_, grad_weight.data_ptr());
+  gradb.init(statistic_desc_, grad_bias.data_ptr());
+
+  save_mean_.init(statistic_desc_, save_mean.data_ptr());
+  save_var_.init(statistic_desc_, save_var.data_ptr());
+  scale_.init(statistic_desc_, weight.data_ptr());
+
+  ideep::tensor gradx_, gradw_, gradb_;
+  ideep::batch_normalization_backward::compute(
+      src_, save_mean_, save_var_, grady_, scale_, gradx_, gradw_, gradb_, epsilon);
+
+  ideep::reorder::compute(gradx_, gradx);
+  ideep::reorder::compute(gradw_, gradw);
+  ideep::reorder::compute(gradb_, gradb);
+
+  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
+}
+
+}} // namespace at::native
+#endif
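
The trailing ideep::reorder::compute calls copy results from MKL-DNN's internal (possibly blocked) layout into the plain contiguous buffers backing the ATen outputs. A minimal sketch of driving the dense-CPU path directly (assumes an MKL-DNN-enabled build; illustrative, not part of the commit):

#include <ATen/ATen.h>
#include <tuple>

int main() {
  // A plain strided float CPU tensor does not carry MkldnnCPUTensorId, so
  // mkldnn_batch_norm takes the contiguous()/init(...) branch above.
  at::Tensor input = at::randn({4, 8, 16, 16});  // NCHW
  at::Tensor weight = at::ones({8});
  at::Tensor bias = at::zeros({8});
  at::Tensor running_mean = at::zeros({8});
  at::Tensor running_var = at::ones({8});

  at::Tensor output, save_mean, save_var;
  std::tie(output, save_mean, save_var) = at::mkldnn_batch_norm(
      input, weight, bias, running_mean, running_var,
      /*training=*/true, /*exponential_average_factor=*/0.1, /*epsilon=*/1e-5);
  return 0;
}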

aten/src/ATen/native/mkldnn/MKLDNNCommon.h

Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,9 @@

 namespace at { namespace native {

+// Construct MKL-DNN tensor given an ideep tensor
+Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options);
+
 // Construct MKL-DNN tensor given `sizes` for allocation
 Tensor new_with_sizes_mkldnn(IntArrayRef sizes, const TensorOptions& options);
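
The new declaration is the inverse of itensor_from_mkldnn: it wraps an already-computed ideep tensor in an ATen Tensor, which is how _ideep_batch_norm returns its outputs. A rough sketch of the pattern, using the same init(descriptor, pointer) idiom as the new file above (hypothetical buffer; ownership is glossed over):

#include <ATen/ATen.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>

int main() {
  // View a caller-owned f32 buffer as a 1-D ideep tensor, then wrap it.
  static float buf[4] = {0.f, 1.f, 2.f, 3.f};
  ideep::tensor::descriptor desc({4}, ideep::tensor::data_type::f32);
  ideep::tensor it;
  it.init(desc, buf);
  at::Tensor t = at::native::new_with_itensor_mkldnn(
      std::move(it), at::device(at::kCPU).dtype(at::kFloat));
  return 0;
}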

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions

@@ -1169,6 +1169,12 @@

 - func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)

+- func: mkldnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+
+- func: mkldnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+
 - func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
   dispatch:
     CUDA: miopen_batch_norm
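
Each func entry generates a dispatchable C++ entry point whose signature mirrors the schema (float in the schema maps to double in C++, matching the definitions in the new file). The declarations landing in the generated NativeFunctions.h should look roughly like this (a hand-written approximation, not the actual generated file):

namespace at { namespace native {
std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    const Tensor& running_mean, const Tensor& running_var,
    bool training, double exponential_average_factor, double epsilon);
std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
    const Tensor& input, const Tensor& grad_output, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var,
    const Tensor& save_mean, const Tensor& save_var, double epsilon);
}} // namespace at::native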

tools/autograd/derivatives.yaml

Lines changed: 7 additions & 0 deletions

@@ -1448,6 +1448,13 @@
 - name: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
   grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, false, false, false, grad_input_mask)

+- name: mkldnn_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
+  input, weight, bias: "training ? mkldnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"
+
+- name: mkldnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
+  save_mean: not_implemented("mkldnn_batch_norm_backward save_mean")
+  save_var: not_implemented("mkldnn_batch_norm_backward save_var")
+  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)

 # fft
 - name: _fft_with_size(Tensor self, int64_t signal_ndim, bool complex_input, bool complex_output, bool inverse, IntArrayRef checked_signal_sizes, bool normalized, bool onesided, IntArrayRef output_sizes)
   self: fft_backward(self, grad, signal_ndim, complex_input, complex_output, inverse, checked_signal_sizes, normalized, onesided, output_sizes)
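
Note the forward derivative routes to mkldnn_batch_norm_backward only when training is true; in inference mode it falls back to native_batch_norm_backward. A sketch of exercising the new formula through autograd (assumes an MKL-DNN-enabled libtorch build; illustrative only):

#include <torch/torch.h>

int main() {
  torch::Tensor input = torch::randn({2, 3, 8, 8}, torch::requires_grad());
  torch::Tensor weight = torch::ones({3}, torch::requires_grad());
  torch::Tensor bias = torch::zeros({3}, torch::requires_grad());
  torch::Tensor running_mean = torch::zeros({3});
  torch::Tensor running_var = torch::ones({3});

  // training=true selects the mkldnn_batch_norm_backward branch above.
  auto outs = torch::mkldnn_batch_norm(
      input, weight, bias, running_mean, running_var,
      /*training=*/true, /*exponential_average_factor=*/0.1, /*epsilon=*/1e-5);
  std::get<0>(outs).sum().backward();
  return 0;
}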
