Skip to content

Commit e01344f

Browse files
committed
Add aten mkldnn conv2d backward operator
1 parent bed1d7d commit e01344f

File tree

9 files changed

+139
-240
lines changed

9 files changed

+139
-240
lines changed

aten/src/ATen/native/mkldnn/Conv.cpp

Lines changed: 104 additions & 234 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ at::Tensor mkldnn_convolution(
1414

1515
at::Tensor mkldnn_convolution_backward_input(
1616
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
17-
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
17+
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
1818
AT_ERROR("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
1919
}
2020

@@ -118,6 +118,70 @@ ideep::tensor _mkldnn_conv2d(
118118
return y;
119119
}
120120

121+
ideep::tensor _mkldnn_conv2d_backward_input(
122+
at::IntArrayRef input_sizes,
123+
const ideep::tensor& grady,
124+
const ideep::tensor& w,
125+
at::IntArrayRef padding,
126+
at::IntArrayRef stride,
127+
at::IntArrayRef dilation,
128+
int64_t groups) {
129+
ideep::tensor gradx;
130+
ideep::convolution_backward_data::compute<AllocForMKLDNN>(
131+
grady,
132+
w,
133+
{input_sizes.cbegin(), input_sizes.cend()},
134+
gradx,
135+
{stride.begin(), stride.end()},
136+
{dilation.begin(), dilation.end()},
137+
{padding.begin(), padding.end()},
138+
{padding.begin(), padding.end()},
139+
groups,
140+
ideep::algorithm::convolution_direct);
141+
142+
return gradx;
143+
}
144+
145+
std::tuple<ideep::tensor, ideep::tensor> _mkldnn_conv2d_backward_weights(
146+
at::IntArrayRef weight_sizes,
147+
const ideep::tensor& grady,
148+
const ideep::tensor& x,
149+
at::IntArrayRef padding,
150+
at::IntArrayRef stride,
151+
at::IntArrayRef dilation,
152+
int64_t groups,
153+
bool bias_defined) {
154+
ideep::tensor gradw, gradb;
155+
if (bias_defined) {
156+
ideep::convolution_backward_weights::compute<AllocForMKLDNN>(
157+
x,
158+
grady,
159+
{weight_sizes.cbegin(), weight_sizes.cend()},
160+
gradw,
161+
gradb,
162+
{stride.begin(), stride.end()},
163+
{dilation.begin(), dilation.end()},
164+
{padding.begin(), padding.end()},
165+
{padding.begin(), padding.end()},
166+
groups,
167+
ideep::algorithm::convolution_direct);
168+
} else {
169+
ideep::convolution_backward_weights::compute<AllocForMKLDNN>(
170+
x,
171+
grady,
172+
{weight_sizes.cbegin(), weight_sizes.cend()},
173+
gradw,
174+
{stride.begin(), stride.end()},
175+
{dilation.begin(), dilation.end()},
176+
{padding.begin(), padding.end()},
177+
{padding.begin(), padding.end()},
178+
groups,
179+
ideep::algorithm::convolution_direct);
180+
}
181+
182+
return std::tuple<ideep::tensor, ideep::tensor>{gradw, gradb};
183+
}
184+
121185
at::Tensor mkldnn_convolution(
122186
const at::Tensor& input,
123187
const at::Tensor& weight,
@@ -152,259 +216,65 @@ at::Tensor mkldnn_convolution(
152216

153217
Tensor mkldnn_convolution_backward_input(
154218
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
155-
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
156-
{
157-
auto grad_input = at::empty(input_size, grad_output.options());
158-
159-
auto cpu_engine = CpuEngine::Instance().get_engine();
160-
161-
int32_t g = groups;
162-
163-
int32_t n = grad_input.size(0);
164-
int32_t ic = grad_input.size(1);
165-
int32_t ih = grad_input.size(2);
166-
int32_t iw = grad_input.size(3);
167-
168-
int32_t oc = grad_output.size(1);
169-
int32_t oh = grad_output.size(2);
170-
int32_t ow = grad_output.size(3);
171-
172-
int32_t kh = weight.size(2);
173-
int32_t kw = weight.size(3);
174-
175-
int32_t sh = stride[0];
176-
int32_t sw = stride[1];
177-
int32_t ph = padding[0];
178-
int32_t pw = padding[1];
179-
180-
auto data_t = memory::data_type::f32;
181-
auto format_any = memory::format::any;
182-
auto format_nchw = memory::format::nchw;
183-
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
184-
185-
memory::dims input_tz = {n, ic, ih, iw};
186-
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
187-
memory::dims bias_tz = {oc};
188-
memory::dims output_tz = {n, oc, oh, ow};
189-
memory::dims _stride = {sh, sw};
190-
memory::dims _padding = {ph, pw};
191-
192-
auto input_md = memory::desc({input_tz}, data_t, format_any);
193-
auto weight_md = memory::desc({weight_tz}, data_t, format_any);
194-
auto bias_md = memory::desc({bias_tz}, data_t, format_any);
195-
auto output_md = memory::desc({output_tz}, data_t, format_any);
196-
197-
// need to re-create conv_forward_pd to feed conv_backward_data_pd
198-
std::shared_ptr<convolution_forward::desc> conv_forward_desc;
199-
if (bias_defined) {
200-
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
201-
convolution_direct, input_md, weight_md, bias_md, output_md,
202-
_stride, _padding, _padding, padding_kind::zero));
203-
} else {
204-
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
205-
convolution_direct, input_md, weight_md, output_md,
206-
_stride, _padding, _padding, padding_kind::zero));
207-
}
208-
209-
std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
210-
conv_forward_pd.reset(new convolution_forward::primitive_desc(
211-
*conv_forward_desc, cpu_engine));
212-
213-
std::shared_ptr<convolution_backward_data::desc> conv_backward_data_desc;
214-
conv_backward_data_desc.reset(new convolution_backward_data::desc(
215-
convolution_direct, input_md, weight_md, output_md,
216-
_stride, _padding, _padding, padding_kind::zero));
217-
218-
std::shared_ptr<convolution_backward_data::primitive_desc> conv_backward_data_pd;
219-
conv_backward_data_pd.reset(new convolution_backward_data::primitive_desc(
220-
*conv_backward_data_desc, cpu_engine, *conv_forward_pd));
221-
222-
auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
223-
grad_output.data_ptr());
224-
auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
225-
weight.data_ptr());
226-
auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
227-
grad_input.data_ptr());
228-
229-
std::vector<primitive> net;
230-
231-
auto grad_output_pd = conv_backward_data_pd->diff_dst_primitive_desc();
232-
auto grad_output_memory = grad_output_usr_memory;
233-
if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
234-
grad_output_memory = memory(grad_output_pd);
235-
net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
236-
}
237-
238-
auto weight_pd = conv_backward_data_pd->weights_primitive_desc();
239-
auto weight_memory = weight_usr_memory;
240-
if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
241-
weight_memory = memory(weight_pd);
242-
net.push_back(reorder(weight_usr_memory, weight_memory));
243-
}
244-
245-
auto grad_input_pd = conv_backward_data_pd->diff_src_primitive_desc();
246-
auto grad_input_memory = grad_input_usr_memory;
247-
if (grad_input_memory.get_primitive_desc() != memory::primitive_desc(grad_input_pd)) {
248-
grad_input_memory = memory(grad_input_pd);
249-
}
250-
251-
std::shared_ptr<convolution_backward_data> conv_backward_data;
252-
conv_backward_data.reset(new convolution_backward_data(*conv_backward_data_pd,
253-
grad_output_memory, weight_memory, grad_input_memory));
254-
net.push_back(*conv_backward_data);
255-
256-
if (grad_input_memory != grad_input_usr_memory) {
257-
net.push_back(reorder(grad_input_memory, grad_input_usr_memory));
258-
}
219+
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
220+
const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output);
221+
const ideep::tensor mkldnn_weight = get_mkldnn_tensor(weight);
259222

260-
Stream::Instance().get_stream().submit(net);
223+
ideep::tensor mkldnn_grad_input = _mkldnn_conv2d_backward_input(
224+
input_size,
225+
mkldnn_grad_output,
226+
mkldnn_weight,
227+
padding,
228+
stride,
229+
dilation,
230+
groups);
261231

262-
return grad_input;
232+
if (grad_output.is_mkldnn()) {
233+
return new_with_itensor_mkldnn(std::move(mkldnn_grad_input), grad_output.options());
234+
} else {
235+
return mkldnn_to_dense(
236+
new_with_itensor_mkldnn(std::move(mkldnn_grad_input), grad_output.options()));
237+
}
263238
}
264239

265240
std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
266241
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
267-
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
268-
{
269-
auto grad_weight = at::empty(weight_size, grad_output.options());
270-
271-
Tensor grad_bias;
272-
if (bias_defined) {
273-
grad_bias = at::empty({grad_output.size(1)}, grad_output.options());
274-
}
275-
276-
auto cpu_engine = CpuEngine::Instance().get_engine();
277-
278-
int32_t g = groups;
279-
280-
int32_t n = input.size(0);
281-
int32_t ic = input.size(1);
282-
int32_t ih = input.size(2);
283-
int32_t iw = input.size(3);
284-
285-
int32_t oc = grad_output.size(1);
286-
int32_t oh = grad_output.size(2);
287-
int32_t ow = grad_output.size(3);
288-
289-
int32_t kh = grad_weight.size(2);
290-
int32_t kw = grad_weight.size(3);
291-
292-
int32_t sh = stride[0];
293-
int32_t sw = stride[1];
294-
int32_t ph = padding[0];
295-
int32_t pw = padding[1];
296-
297-
auto data_t = memory::data_type::f32;
298-
auto format_any = memory::format::any;
299-
auto format_nchw = memory::format::nchw;
300-
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
301-
auto format_x = memory::format::x;
302-
303-
memory::dims input_tz = {n, ic, ih, iw};
304-
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
305-
memory::dims bias_tz = {oc};
306-
memory::dims output_tz = {n, oc, oh, ow};
307-
memory::dims _stride = {sh, sw};
308-
memory::dims _padding = {ph, pw};
309-
310-
memory::desc input_md({input_tz}, data_t, format_any);
311-
memory::desc weight_md({weight_tz}, data_t, format_any);
312-
memory::desc bias_md({bias_tz}, data_t, format_any);
313-
memory::desc output_md({output_tz}, data_t, format_any);
314-
315-
// need to re-create conv_forward_pd to feed conv_backward_weight_pd
316-
std::shared_ptr<convolution_forward::desc> conv_forward_desc;
317-
if (bias_defined) {
318-
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
319-
convolution_direct, input_md, weight_md, bias_md, output_md,
320-
_stride, _padding, _padding, padding_kind::zero));
321-
} else {
322-
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
323-
convolution_direct, input_md, weight_md, output_md,
324-
_stride, _padding, _padding, padding_kind::zero));
325-
}
326-
327-
std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
328-
conv_forward_pd.reset(new convolution_forward::primitive_desc(
329-
*conv_forward_desc, cpu_engine));
330-
331-
std::shared_ptr<convolution_backward_weights::desc> conv_backward_weight_desc;
332-
if (bias_defined) {
333-
conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
334-
convolution_direct, input_md, weight_md, bias_md, output_md,
335-
_stride, _padding, _padding, padding_kind::zero));
336-
} else {
337-
conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
338-
convolution_direct, input_md, weight_md, output_md,
339-
_stride, _padding, _padding, padding_kind::zero));
340-
}
341-
342-
std::shared_ptr<convolution_backward_weights::primitive_desc> conv_backward_weight_pd;
343-
conv_backward_weight_pd.reset(new convolution_backward_weights::primitive_desc(
344-
*conv_backward_weight_desc, cpu_engine, *conv_forward_pd));
345-
346-
auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
347-
input.data_ptr());
348-
auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
349-
grad_output.data_ptr());
350-
auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
351-
grad_weight.data_ptr());
352-
std::shared_ptr<memory> grad_bias_memory;
353-
354-
std::vector<primitive> net;
355-
356-
auto input_pd = conv_backward_weight_pd->src_primitive_desc();
357-
auto input_memory = input_usr_memory;
358-
if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
359-
input_memory = memory(input_pd);
360-
net.push_back(reorder(input_usr_memory, input_memory));
361-
}
362-
363-
auto grad_output_pd = conv_backward_weight_pd->diff_dst_primitive_desc();
364-
auto grad_output_memory = grad_output_usr_memory;
365-
if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
366-
grad_output_memory = memory(grad_output_pd);
367-
net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
368-
}
242+
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
243+
const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output);
244+
const ideep::tensor mkldnn_input = get_mkldnn_tensor(input);
369245

370-
auto grad_weight_pd = conv_backward_weight_pd->diff_weights_primitive_desc();
371-
auto grad_weight_memory = grad_weight_usr_memory;
372-
if (grad_weight_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_weight_pd)) {
373-
grad_weight_memory = memory(grad_weight_pd);
374-
}
246+
ideep::tensor mkldnn_grad_weight, mkldnn_grad_bias;
247+
std::tie(mkldnn_grad_weight, mkldnn_grad_bias) =_mkldnn_conv2d_backward_weights(
248+
weight_size,
249+
mkldnn_grad_output,
250+
mkldnn_input,
251+
padding,
252+
stride,
253+
dilation,
254+
groups,
255+
bias_defined);
375256

376-
std::shared_ptr<convolution_backward_weights> conv_backward_weight;
377-
if (bias_defined) {
378-
grad_bias_memory.reset(new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
379-
grad_bias.data_ptr()));
380-
conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
381-
input_memory, grad_output_memory, grad_weight_memory, *grad_bias_memory));
257+
if (grad_output.is_mkldnn()) {
258+
return std::tuple<at::Tensor, at::Tensor>{
259+
new_with_itensor_mkldnn(std::move(mkldnn_grad_weight), grad_output.options()),
260+
new_with_itensor_mkldnn(std::move(mkldnn_grad_bias), grad_output.options())};
382261
} else {
383-
conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
384-
input_memory, grad_output_memory, grad_weight_memory));
385-
}
386-
387-
net.push_back(*conv_backward_weight);
388-
389-
if (grad_weight_memory != grad_weight_usr_memory) {
390-
net.push_back(reorder(grad_weight_memory, grad_weight_usr_memory));
262+
return std::tuple<at::Tensor, at::Tensor>{
263+
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight), grad_output.options())),
264+
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias), grad_output.options()))};
391265
}
392-
393-
Stream::Instance().get_stream().submit(net);
394-
395-
return std::tuple<at::Tensor, at::Tensor>{grad_weight, grad_bias};
396266
}
397267

398268
std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
399269
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
400270
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
401271
{
402-
Tensor grad_output = grad_output_t.contiguous();
272+
Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous();
403273

404274
Tensor grad_input, grad_weight, grad_bias;
405275
if (output_mask[0]) {
406276
grad_input = at::mkldnn_convolution_backward_input(
407-
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
277+
input.sizes(), grad_output, weight, padding, stride, dilation, groups);
408278
}
409279
if (output_mask[1] || output_mask[2]) {
410280
std::tie(grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights(

aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
1515
Tensor cpu_tensor = at::empty(
1616
std::vector<int64_t>(dims.begin(), dims.end()),
1717
mkldnn_tensor.options().layout(c10::kStrided));
18-
stensor.reorder_to(cpu_tensor.template data<float>());
18+
if (!stensor.is_empty()) {
19+
stensor.reorder_to(cpu_tensor.template data<float>());
20+
}
1921
return cpu_tensor;
2022
}
2123

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@
12091209

12101210
- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
12111211

1212-
- func: mkldnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool bias_defined) -> Tensor
1212+
- func: mkldnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
12131213

12141214
- func: mkldnn_convolution_backward_weights(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool bias_defined) -> (Tensor, Tensor)
12151215

0 commit comments

Comments (0)