@@ -14,7 +14,7 @@ at::Tensor mkldnn_convolution(
 
 at::Tensor mkldnn_convolution_backward_input(
     IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
-    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
   AT_ERROR("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
 }
 
@@ -118,6 +118,70 @@ ideep::tensor _mkldnn_conv2d(
   return y;
 }
 
+ideep::tensor _mkldnn_conv2d_backward_input(
+    at::IntArrayRef input_sizes,
+    const ideep::tensor& grady,
+    const ideep::tensor& w,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups) {
+  ideep::tensor gradx;
+  ideep::convolution_backward_data::compute<AllocForMKLDNN>(
+      grady,
+      w,
+      {input_sizes.cbegin(), input_sizes.cend()},
+      gradx,
+      {stride.begin(), stride.end()},
+      {dilation.begin(), dilation.end()},
+      {padding.begin(), padding.end()},
+      {padding.begin(), padding.end()},
+      groups,
+      ideep::algorithm::convolution_direct);
+
+  return gradx;
+}
+
+std::tuple<ideep::tensor, ideep::tensor> _mkldnn_conv2d_backward_weights(
+    at::IntArrayRef weight_sizes,
+    const ideep::tensor& grady,
+    const ideep::tensor& x,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    bool bias_defined) {
+  ideep::tensor gradw, gradb;
+  if (bias_defined) {
+    ideep::convolution_backward_weights::compute<AllocForMKLDNN>(
+        x,
+        grady,
+        {weight_sizes.cbegin(), weight_sizes.cend()},
+        gradw,
+        gradb,
+        {stride.begin(), stride.end()},
+        {dilation.begin(), dilation.end()},
+        {padding.begin(), padding.end()},
+        {padding.begin(), padding.end()},
+        groups,
+        ideep::algorithm::convolution_direct);
+  } else {
+    ideep::convolution_backward_weights::compute<AllocForMKLDNN>(
+        x,
+        grady,
+        {weight_sizes.cbegin(), weight_sizes.cend()},
+        gradw,
+        {stride.begin(), stride.end()},
+        {dilation.begin(), dilation.end()},
+        {padding.begin(), padding.end()},
+        {padding.begin(), padding.end()},
+        groups,
+        ideep::algorithm::convolution_direct);
+  }
+
+  return std::tuple<ideep::tensor, ideep::tensor>{gradw, gradb};
+}
+
 at::Tensor mkldnn_convolution(
     const at::Tensor& input,
     const at::Tensor& weight,
@@ -152,259 +216,65 @@ at::Tensor mkldnn_convolution(
 
 Tensor mkldnn_convolution_backward_input(
     IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
-    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
-{
-  auto grad_input = at::empty(input_size, grad_output.options());
-
-  auto cpu_engine = CpuEngine::Instance().get_engine();
-
-  int32_t g = groups;
-
-  int32_t n = grad_input.size(0);
-  int32_t ic = grad_input.size(1);
-  int32_t ih = grad_input.size(2);
-  int32_t iw = grad_input.size(3);
-
-  int32_t oc = grad_output.size(1);
-  int32_t oh = grad_output.size(2);
-  int32_t ow = grad_output.size(3);
-
-  int32_t kh = weight.size(2);
-  int32_t kw = weight.size(3);
-
-  int32_t sh = stride[0];
-  int32_t sw = stride[1];
-  int32_t ph = padding[0];
-  int32_t pw = padding[1];
-
-  auto data_t = memory::data_type::f32;
-  auto format_any = memory::format::any;
-  auto format_nchw = memory::format::nchw;
-  auto format_weight = (g != 1) ? memory::format::goihw : memory::format::oihw;
-
-  memory::dims input_tz = {n, ic, ih, iw};
-  memory::dims weight_tz = (g != 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
-  memory::dims bias_tz = {oc};
-  memory::dims output_tz = {n, oc, oh, ow};
-  memory::dims _stride = {sh, sw};
-  memory::dims _padding = {ph, pw};
-
-  auto input_md = memory::desc({input_tz}, data_t, format_any);
-  auto weight_md = memory::desc({weight_tz}, data_t, format_any);
-  auto bias_md = memory::desc({bias_tz}, data_t, format_any);
-  auto output_md = memory::desc({output_tz}, data_t, format_any);
-
-  // need to re-create conv_forward_pd to feed conv_backward_data_pd
-  std::shared_ptr<convolution_forward::desc> conv_forward_desc;
-  if (bias_defined) {
-    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
-      convolution_direct, input_md, weight_md, bias_md, output_md,
-      _stride, _padding, _padding, padding_kind::zero));
-  } else {
-    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
-      convolution_direct, input_md, weight_md, output_md,
-      _stride, _padding, _padding, padding_kind::zero));
-  }
-
-  std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
-  conv_forward_pd.reset(new convolution_forward::primitive_desc(
-    *conv_forward_desc, cpu_engine));
-
-  std::shared_ptr<convolution_backward_data::desc> conv_backward_data_desc;
-  conv_backward_data_desc.reset(new convolution_backward_data::desc(
-    convolution_direct, input_md, weight_md, output_md,
-    _stride, _padding, _padding, padding_kind::zero));
-
-  std::shared_ptr<convolution_backward_data::primitive_desc> conv_backward_data_pd;
-  conv_backward_data_pd.reset(new convolution_backward_data::primitive_desc(
-    *conv_backward_data_desc, cpu_engine, *conv_forward_pd));
-
-  auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
-    grad_output.data_ptr());
-  auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
-    weight.data_ptr());
-  auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
-    grad_input.data_ptr());
-
-  std::vector<primitive> net;
-
-  auto grad_output_pd = conv_backward_data_pd->diff_dst_primitive_desc();
-  auto grad_output_memory = grad_output_usr_memory;
-  if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
-    grad_output_memory = memory(grad_output_pd);
-    net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
-  }
-
-  auto weight_pd = conv_backward_data_pd->weights_primitive_desc();
-  auto weight_memory = weight_usr_memory;
-  if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
-    weight_memory = memory(weight_pd);
-    net.push_back(reorder(weight_usr_memory, weight_memory));
-  }
-
-  auto grad_input_pd = conv_backward_data_pd->diff_src_primitive_desc();
-  auto grad_input_memory = grad_input_usr_memory;
-  if (grad_input_memory.get_primitive_desc() != memory::primitive_desc(grad_input_pd)) {
-    grad_input_memory = memory(grad_input_pd);
-  }
-
-  std::shared_ptr<convolution_backward_data> conv_backward_data;
-  conv_backward_data.reset(new convolution_backward_data(*conv_backward_data_pd,
-    grad_output_memory, weight_memory, grad_input_memory));
-  net.push_back(*conv_backward_data);
-
-  if (grad_input_memory != grad_input_usr_memory) {
-    net.push_back(reorder(grad_input_memory, grad_input_usr_memory));
-  }
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
+  const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output);
+  const ideep::tensor mkldnn_weight = get_mkldnn_tensor(weight);
 
-  Stream::Instance().get_stream().submit(net);
+  ideep::tensor mkldnn_grad_input = _mkldnn_conv2d_backward_input(
+      input_size,
+      mkldnn_grad_output,
+      mkldnn_weight,
+      padding,
+      stride,
+      dilation,
+      groups);
 
-  return grad_input;
+  if (grad_output.is_mkldnn()) {
+    return new_with_itensor_mkldnn(std::move(mkldnn_grad_input), grad_output.options());
+  } else {
+    return mkldnn_to_dense(
+        new_with_itensor_mkldnn(std::move(mkldnn_grad_input), grad_output.options()));
+  }
 }
 
 std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
     IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
-    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
-{
-  auto grad_weight = at::empty(weight_size, grad_output.options());
-
-  Tensor grad_bias;
-  if (bias_defined) {
-    grad_bias = at::empty({grad_output.size(1)}, grad_output.options());
-  }
-
-  auto cpu_engine = CpuEngine::Instance().get_engine();
-
-  int32_t g = groups;
-
-  int32_t n = input.size(0);
-  int32_t ic = input.size(1);
-  int32_t ih = input.size(2);
-  int32_t iw = input.size(3);
-
-  int32_t oc = grad_output.size(1);
-  int32_t oh = grad_output.size(2);
-  int32_t ow = grad_output.size(3);
-
-  int32_t kh = grad_weight.size(2);
-  int32_t kw = grad_weight.size(3);
-
-  int32_t sh = stride[0];
-  int32_t sw = stride[1];
-  int32_t ph = padding[0];
-  int32_t pw = padding[1];
-
-  auto data_t = memory::data_type::f32;
-  auto format_any = memory::format::any;
-  auto format_nchw = memory::format::nchw;
-  auto format_weight = (g != 1) ? memory::format::goihw : memory::format::oihw;
-  auto format_x = memory::format::x;
-
-  memory::dims input_tz = {n, ic, ih, iw};
-  memory::dims weight_tz = (g != 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
-  memory::dims bias_tz = {oc};
-  memory::dims output_tz = {n, oc, oh, ow};
-  memory::dims _stride = {sh, sw};
-  memory::dims _padding = {ph, pw};
-
-  memory::desc input_md({input_tz}, data_t, format_any);
-  memory::desc weight_md({weight_tz}, data_t, format_any);
-  memory::desc bias_md({bias_tz}, data_t, format_any);
-  memory::desc output_md({output_tz}, data_t, format_any);
-
-  // need to re-create conv_forward_pd to feed conv_backward_weight_pd
-  std::shared_ptr<convolution_forward::desc> conv_forward_desc;
-  if (bias_defined) {
-    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
-      convolution_direct, input_md, weight_md, bias_md, output_md,
-      _stride, _padding, _padding, padding_kind::zero));
-  } else {
-    conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
-      convolution_direct, input_md, weight_md, output_md,
-      _stride, _padding, _padding, padding_kind::zero));
-  }
-
-  std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
-  conv_forward_pd.reset(new convolution_forward::primitive_desc(
-    *conv_forward_desc, cpu_engine));
-
-  std::shared_ptr<convolution_backward_weights::desc> conv_backward_weight_desc;
-  if (bias_defined) {
-    conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
-      convolution_direct, input_md, weight_md, bias_md, output_md,
-      _stride, _padding, _padding, padding_kind::zero));
-  } else {
-    conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
-      convolution_direct, input_md, weight_md, output_md,
-      _stride, _padding, _padding, padding_kind::zero));
-  }
-
-  std::shared_ptr<convolution_backward_weights::primitive_desc> conv_backward_weight_pd;
-  conv_backward_weight_pd.reset(new convolution_backward_weights::primitive_desc(
-    *conv_backward_weight_desc, cpu_engine, *conv_forward_pd));
-
-  auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
-    input.data_ptr());
-  auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
-    grad_output.data_ptr());
-  auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
-    grad_weight.data_ptr());
-  std::shared_ptr<memory> grad_bias_memory;
-
-  std::vector<primitive> net;
-
-  auto input_pd = conv_backward_weight_pd->src_primitive_desc();
-  auto input_memory = input_usr_memory;
-  if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
-    input_memory = memory(input_pd);
-    net.push_back(reorder(input_usr_memory, input_memory));
-  }
-
-  auto grad_output_pd = conv_backward_weight_pd->diff_dst_primitive_desc();
-  auto grad_output_memory = grad_output_usr_memory;
-  if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
-    grad_output_memory = memory(grad_output_pd);
-    net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
-  }
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
+  const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output);
+  const ideep::tensor mkldnn_input = get_mkldnn_tensor(input);
 
-  auto grad_weight_pd = conv_backward_weight_pd->diff_weights_primitive_desc();
-  auto grad_weight_memory = grad_weight_usr_memory;
-  if (grad_weight_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_weight_pd)) {
-    grad_weight_memory = memory(grad_weight_pd);
-  }
+  ideep::tensor mkldnn_grad_weight, mkldnn_grad_bias;
+  std::tie(mkldnn_grad_weight, mkldnn_grad_bias) = _mkldnn_conv2d_backward_weights(
+      weight_size,
+      mkldnn_grad_output,
+      mkldnn_input,
+      padding,
+      stride,
+      dilation,
+      groups,
+      bias_defined);
 
-  std::shared_ptr<convolution_backward_weights> conv_backward_weight;
-  if (bias_defined) {
-    grad_bias_memory.reset(new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
-      grad_bias.data_ptr()));
-    conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
-      input_memory, grad_output_memory, grad_weight_memory, *grad_bias_memory));
+  if (grad_output.is_mkldnn()) {
+    return std::tuple<at::Tensor, at::Tensor>{
+        new_with_itensor_mkldnn(std::move(mkldnn_grad_weight), grad_output.options()),
+        new_with_itensor_mkldnn(std::move(mkldnn_grad_bias), grad_output.options())};
   } else {
-    conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
-      input_memory, grad_output_memory, grad_weight_memory));
-  }
-
-  net.push_back(*conv_backward_weight);
-
-  if (grad_weight_memory != grad_weight_usr_memory) {
-    net.push_back(reorder(grad_weight_memory, grad_weight_usr_memory));
+    return std::tuple<at::Tensor, at::Tensor>{
+        mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight), grad_output.options())),
+        mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias), grad_output.options()))};
   }
-
-  Stream::Instance().get_stream().submit(net);
-
-  return std::tuple<at::Tensor, at::Tensor>{grad_weight, grad_bias};
 }
 
 std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
     const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
 {
-  Tensor grad_output = grad_output_t.contiguous();
+  Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous();
 
   Tensor grad_input, grad_weight, grad_bias;
   if (output_mask[0]) {
     grad_input = at::mkldnn_convolution_backward_input(
-        input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
+        input.sizes(), grad_output, weight, padding, stride, dilation, groups);
   }
   if (output_mask[1] || output_mask[2]) {
     std::tie(grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights(
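
For context (not part of the commit): a minimal sketch of how the refactored backward entry point shown above might be exercised from C++, assuming an ATen build with MKL-DNN support. The tensor shapes and convolution parameters here are arbitrary illustrative values, not anything taken from the patch.

// Illustrative only -- assumes ATen was built with MKLDNN enabled.
#include <ATen/ATen.h>
#include <tuple>

int main() {
  // NCHW activations and OIHW weights; sizes chosen arbitrarily.
  at::Tensor input = at::randn({1, 3, 8, 8});
  at::Tensor weight = at::randn({6, 3, 3, 3});
  // grad_output matches a 3x3 conv with padding 1, stride 1 on the input above.
  at::Tensor grad_output = at::randn({1, 6, 8, 8});

  // output_mask selects which of {grad_input, grad_weight, grad_bias} to compute.
  at::Tensor grad_input, grad_weight, grad_bias;
  std::tie(grad_input, grad_weight, grad_bias) = at::mkldnn_convolution_backward(
      input, grad_output, weight,
      /*padding=*/{1, 1}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
      /*groups=*/1, /*output_mask=*/{{true, true, true}});
  return 0;
}

With this change, grad_output may be either a dense CPU tensor or an MKL-DNN tensor; the get_mkldnn_tensor helpers in the patch handle both, and the outputs mirror the layout of grad_output.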