@@ -8,25 +8,25 @@ namespace at { namespace native {
88
99at::Tensor mkldnn_convolution (
1010 const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
11- IntList padding, IntList stride, IntList dilation) {
11+ IntList padding, IntList stride, IntList dilation, int64_t groups ) {
1212 throw std::runtime_error (" mkldnn_convolution_forward: ATen not compiled with MKLDNN support" );
1313}
1414
1515at::Tensor mkldnn_convolution_backward_input (
1616 IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
17- IntList padding, IntList stride, IntList dilation, bool bias_defined) {
17+ IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined) {
1818 throw std::runtime_error (" mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support" );
1919}
2020
2121std::tuple<at::Tensor,at::Tensor> mkldnn_convolution_backward_weights (
2222 IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
23- IntList padding, IntList stride, IntList dilation, bool bias_defined) {
23+ IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined) {
2424 throw std::runtime_error (" mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support" );
2525}
2626
2727std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward (
2828 const at::Tensor& input, const at::Tensor& grad_output_t , const at::Tensor& weight,
29- IntList padding, IntList stride, IntList dilation, std::array<bool ,3 > output_mask) {
29+ IntList padding, IntList stride, IntList dilation, int64_t groups, std::array<bool ,3 > output_mask) {
3030 throw std::runtime_error (" mkldnn_convolution_backward: ATen not compiled with MKLDNN support" );
3131}
3232
@@ -52,7 +52,7 @@ constexpr int max_dim = 3;
5252
5353std::vector<int64_t > conv_output_size (
5454 IntList input_size, IntList weight_size,
55- IntList padding, IntList stride, IntList dilation)
55+ IntList padding, IntList stride, IntList dilation, int64_t groups )
5656{
5757 auto dim = input_size.size ();
5858 std::vector<int64_t > output_size (dim);
@@ -68,12 +68,14 @@ std::vector<int64_t> conv_output_size(
6868
6969at::Tensor mkldnn_convolution (
7070 const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
71- IntList padding, IntList stride, IntList dilation)
71+ IntList padding, IntList stride, IntList dilation, int64_t groups )
7272{
7373 auto output = input.type ().tensor (conv_output_size (
74- input.sizes (), weight.sizes (), padding, stride, dilation));
74+ input.sizes (), weight.sizes (), padding, stride, dilation, groups ));
7575
7676 auto cpu_engine = CpuEngine::Instance ().get_engine ();
77+
78+ int32_t g = groups;
7779
7880 int32_t n = input.size (0 );
7981 int32_t ic = input.size (1 );
@@ -95,11 +97,11 @@ at::Tensor mkldnn_convolution(
9597 auto data_t = memory::data_type::f32 ;
9698 auto format_any = memory::format::any;
9799 auto format_nchw = memory::format::nchw;
98- auto format_oihw = memory::format::oihw;
100+ auto format_weight = (g!= 1 ) ? memory::format::goihw : memory::format::oihw;
99101 auto format_x = memory::format::x;
100102
101103 memory::dims input_tz = {n, ic, ih, iw};
102- memory::dims weight_tz = {oc, ic, kh, kw};
104+ memory::dims weight_tz = (g!= 1 ) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims {oc, ic, kh, kw};
103105 memory::dims bias_tz = {oc};
104106 memory::dims output_tz = {n, oc, oh, ow};
105107 memory::dims _stride = {sh, sw};
@@ -127,7 +129,7 @@ at::Tensor mkldnn_convolution(
127129
128130 auto input_usr_memory = memory ({{{input_tz}, data_t , format_nchw}, cpu_engine},
129131 input.data_ptr ());
130- auto weight_usr_memory = memory ({{{weight_tz}, data_t , format_oihw }, cpu_engine},
132+ auto weight_usr_memory = memory ({{{weight_tz}, data_t , format_weight }, cpu_engine},
131133 weight.data_ptr ());
132134 auto output_usr_memory = memory ({{{output_tz}, data_t , format_nchw}, cpu_engine},
133135 output.data_ptr ());
@@ -178,12 +180,14 @@ at::Tensor mkldnn_convolution(
178180
179181Tensor mkldnn_convolution_backward_input (
180182 IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
181- IntList padding, IntList stride, IntList dilation, bool bias_defined)
183+ IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined)
182184{
183185 auto grad_input = grad_output.type ().tensor (input_size);
184186
185187 auto cpu_engine = CpuEngine::Instance ().get_engine ();
186188
189+ int32_t g = groups;
190+
187191 int32_t n = grad_input.size (0 );
188192 int32_t ic = grad_input.size (1 );
189193 int32_t ih = grad_input.size (2 );
@@ -204,10 +208,10 @@ Tensor mkldnn_convolution_backward_input(
204208 auto data_t = memory::data_type::f32 ;
205209 auto format_any = memory::format::any;
206210 auto format_nchw = memory::format::nchw;
207- auto format_oihw = memory::format::oihw;
211+ auto format_weight = (g!= 1 ) ? memory::format::goihw : memory::format::oihw;
208212
209213 memory::dims input_tz = {n, ic, ih, iw};
210- memory::dims weight_tz = {oc, ic, kh, kw};
214+ memory::dims weight_tz = (g!= 1 ) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims {oc, ic, kh, kw};
211215 memory::dims bias_tz = {oc};
212216 memory::dims output_tz = {n, oc, oh, ow};
213217 memory::dims _stride = {sh, sw};
@@ -245,7 +249,7 @@ Tensor mkldnn_convolution_backward_input(
245249
246250 auto grad_output_usr_memory = memory ({{{output_tz}, data_t , format_nchw}, cpu_engine},
247251 grad_output.data_ptr ());
248- auto weight_usr_memory = memory ({{{weight_tz}, data_t , format_oihw }, cpu_engine},
252+ auto weight_usr_memory = memory ({{{weight_tz}, data_t , format_weight }, cpu_engine},
249253 weight.data_ptr ());
250254 auto grad_input_usr_memory = memory ({{{input_tz}, data_t , format_nchw}, cpu_engine},
251255 grad_input.data_ptr ());
@@ -288,7 +292,7 @@ Tensor mkldnn_convolution_backward_input(
288292
289293std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights (
290294 IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
291- IntList padding, IntList stride, IntList dilation, bool bias_defined)
295+ IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined)
292296{
293297 auto grad_weight = grad_output.type ().tensor (weight_size);
294298
@@ -299,6 +303,8 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
299303
300304 auto cpu_engine = CpuEngine::Instance ().get_engine ();
301305
306+ int32_t g = groups;
307+
302308 int32_t n = input.size (0 );
303309 int32_t ic = input.size (1 );
304310 int32_t ih = input.size (2 );
@@ -319,11 +325,11 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
319325 auto data_t = memory::data_type::f32 ;
320326 auto format_any = memory::format::any;
321327 auto format_nchw = memory::format::nchw;
322- auto format_oihw = memory::format::oihw;
328+ auto format_weight = (g!= 1 ) ? memory::format::goihw : memory::format::oihw;
323329 auto format_x = memory::format::x;
324330
325331 memory::dims input_tz = {n, ic, ih, iw};
326- memory::dims weight_tz = {oc, ic, kh, kw};
332+ memory::dims weight_tz = (g!= 1 ) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims {oc, ic, kh, kw};
327333 memory::dims bias_tz = {oc};
328334 memory::dims output_tz = {n, oc, oh, ow};
329335 memory::dims _stride = {sh, sw};
@@ -369,7 +375,7 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
369375 input.data_ptr ());
370376 auto grad_output_usr_memory = memory ({{{output_tz}, data_t , format_nchw}, cpu_engine},
371377 grad_output.data_ptr ());
372- auto grad_weight_usr_memory = memory ({{{weight_tz}, data_t , format_oihw }, cpu_engine},
378+ auto grad_weight_usr_memory = memory ({{{weight_tz}, data_t , format_weight }, cpu_engine},
373379 grad_weight.data_ptr ());
374380 std::shared_ptr<memory> grad_bias_memory;
375381
@@ -419,18 +425,18 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
419425
420426std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward (
421427 const at::Tensor& input, const at::Tensor& grad_output_t , const at::Tensor& weight,
422- IntList padding, IntList stride, IntList dilation, std::array<bool ,3 > output_mask)
428+ IntList padding, IntList stride, IntList dilation, int64_t groups, std::array<bool ,3 > output_mask)
423429{
424430 Tensor grad_output = grad_output_t .contiguous ();
425431
426432 Tensor grad_input, grad_weight, grad_bias;
427433 if (output_mask[0 ]) {
428434 grad_input = at::mkldnn_convolution_backward_input (
429- input.sizes (), grad_output, weight, padding, stride, dilation, output_mask[2 ]);
435+ input.sizes (), grad_output, weight, padding, stride, dilation, groups, output_mask[2 ]);
430436 }
431437 if (output_mask[1 ] || output_mask[2 ]) {
432438 std::tie (grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights (
433- weight.sizes (), grad_output, input, padding, stride, dilation, output_mask[2 ]);
439+ weight.sizes (), grad_output, input, padding, stride, dilation, groups, output_mask[2 ]);
434440 }
435441
436442 return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
0 commit comments