Skip to content

Commit 8b4622f

Browse files
committed
cnn: add depthwise conv support for mkldnn
Change-Id: I3836dacc63afc1b5e31b1d706bba6bb13699ba41
1 parent 34554d6 commit 8b4622f

File tree

4 files changed

+35
-30
lines changed

4 files changed

+35
-30
lines changed

aten/src/ATen/native/Convolution.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
124124
input.type().scalarType() == kFloat && // only on CPU Float Tensors
125125
!is_dilated() && // doesn't support dilation
126126
!transposed && // or transposed tensors
127-
input.ndimension() == 4 && // must be in NCHW format
128-
groups == 1;
127+
input.ndimension() == 4; // must be in NCHW format
129128
#endif
130129
return false;
131130
}
@@ -369,7 +368,7 @@ at::Tensor _convolution(
369368
throw std::runtime_error(ss.str());
370369
}
371370

372-
output = at::mkldnn_convolution(input, weight, bias, params.padding, params.stride, params.dilation);
371+
output = at::mkldnn_convolution(input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
373372
#endif
374373
} else {
375374
if (params.groups == 1) {

aten/src/ATen/native/mkldnn/Conv.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,25 @@ namespace at { namespace native {
88

99
at::Tensor mkldnn_convolution(
1010
const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
11-
IntList padding, IntList stride, IntList dilation) {
11+
IntList padding, IntList stride, IntList dilation, int64_t groups) {
1212
throw std::runtime_error("mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
1313
}
1414

1515
at::Tensor mkldnn_convolution_backward_input(
1616
IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
17-
IntList padding, IntList stride, IntList dilation, bool bias_defined) {
17+
IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined) {
1818
throw std::runtime_error("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
1919
}
2020

2121
std::tuple<at::Tensor,at::Tensor> mkldnn_convolution_backward_weights(
2222
IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
23-
IntList padding, IntList stride, IntList dilation, bool bias_defined) {
23+
IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined) {
2424
throw std::runtime_error("mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
2525
}
2626

2727
std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
2828
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
29-
IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask) {
29+
IntList padding, IntList stride, IntList dilation, int64_t groups, std::array<bool,3> output_mask) {
3030
throw std::runtime_error("mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
3131
}
3232

@@ -52,7 +52,7 @@ constexpr int max_dim = 3;
5252

5353
std::vector<int64_t> conv_output_size(
5454
IntList input_size, IntList weight_size,
55-
IntList padding, IntList stride, IntList dilation)
55+
IntList padding, IntList stride, IntList dilation, int64_t groups)
5656
{
5757
auto dim = input_size.size();
5858
std::vector<int64_t> output_size(dim);
@@ -68,12 +68,14 @@ std::vector<int64_t> conv_output_size(
6868

6969
at::Tensor mkldnn_convolution(
7070
const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
71-
IntList padding, IntList stride, IntList dilation)
71+
IntList padding, IntList stride, IntList dilation, int64_t groups)
7272
{
7373
auto output = input.type().tensor(conv_output_size(
74-
input.sizes(), weight.sizes(), padding, stride, dilation));
74+
input.sizes(), weight.sizes(), padding, stride, dilation, groups));
7575

7676
auto cpu_engine = CpuEngine::Instance().get_engine();
77+
78+
int32_t g = groups;
7779

7880
int32_t n = input.size(0);
7981
int32_t ic = input.size(1);
@@ -95,11 +97,11 @@ at::Tensor mkldnn_convolution(
9597
auto data_t = memory::data_type::f32;
9698
auto format_any = memory::format::any;
9799
auto format_nchw = memory::format::nchw;
98-
auto format_oihw = memory::format::oihw;
100+
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
99101
auto format_x = memory::format::x;
100102

101103
memory::dims input_tz = {n, ic, ih, iw};
102-
memory::dims weight_tz = {oc, ic, kh, kw};
104+
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
103105
memory::dims bias_tz = {oc};
104106
memory::dims output_tz = {n, oc, oh, ow};
105107
memory::dims _stride = {sh, sw};
@@ -127,7 +129,7 @@ at::Tensor mkldnn_convolution(
127129

128130
auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
129131
input.data_ptr());
130-
auto weight_usr_memory = memory({{{weight_tz}, data_t, format_oihw}, cpu_engine},
132+
auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
131133
weight.data_ptr());
132134
auto output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
133135
output.data_ptr());
@@ -178,12 +180,14 @@ at::Tensor mkldnn_convolution(
178180

179181
Tensor mkldnn_convolution_backward_input(
180182
IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
181-
IntList padding, IntList stride, IntList dilation, bool bias_defined)
183+
IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined)
182184
{
183185
auto grad_input = grad_output.type().tensor(input_size);
184186

185187
auto cpu_engine = CpuEngine::Instance().get_engine();
186188

189+
int32_t g = groups;
190+
187191
int32_t n = grad_input.size(0);
188192
int32_t ic = grad_input.size(1);
189193
int32_t ih = grad_input.size(2);
@@ -204,10 +208,10 @@ Tensor mkldnn_convolution_backward_input(
204208
auto data_t = memory::data_type::f32;
205209
auto format_any = memory::format::any;
206210
auto format_nchw = memory::format::nchw;
207-
auto format_oihw = memory::format::oihw;
211+
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
208212

209213
memory::dims input_tz = {n, ic, ih, iw};
210-
memory::dims weight_tz = {oc, ic, kh, kw};
214+
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
211215
memory::dims bias_tz = {oc};
212216
memory::dims output_tz = {n, oc, oh, ow};
213217
memory::dims _stride = {sh, sw};
@@ -245,7 +249,7 @@ Tensor mkldnn_convolution_backward_input(
245249

246250
auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
247251
grad_output.data_ptr());
248-
auto weight_usr_memory = memory({{{weight_tz}, data_t, format_oihw}, cpu_engine},
252+
auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
249253
weight.data_ptr());
250254
auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
251255
grad_input.data_ptr());
@@ -288,7 +292,7 @@ Tensor mkldnn_convolution_backward_input(
288292

289293
std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
290294
IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
291-
IntList padding, IntList stride, IntList dilation, bool bias_defined)
295+
IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined)
292296
{
293297
auto grad_weight = grad_output.type().tensor(weight_size);
294298

@@ -299,6 +303,8 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
299303

300304
auto cpu_engine = CpuEngine::Instance().get_engine();
301305

306+
int32_t g = groups;
307+
302308
int32_t n = input.size(0);
303309
int32_t ic = input.size(1);
304310
int32_t ih = input.size(2);
@@ -319,11 +325,11 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
319325
auto data_t = memory::data_type::f32;
320326
auto format_any = memory::format::any;
321327
auto format_nchw = memory::format::nchw;
322-
auto format_oihw = memory::format::oihw;
328+
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
323329
auto format_x = memory::format::x;
324330

325331
memory::dims input_tz = {n, ic, ih, iw};
326-
memory::dims weight_tz = {oc, ic, kh, kw};
332+
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
327333
memory::dims bias_tz = {oc};
328334
memory::dims output_tz = {n, oc, oh, ow};
329335
memory::dims _stride = {sh, sw};
@@ -369,7 +375,7 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
369375
input.data_ptr());
370376
auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
371377
grad_output.data_ptr());
372-
auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_oihw}, cpu_engine},
378+
auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
373379
grad_weight.data_ptr());
374380
std::shared_ptr<memory> grad_bias_memory;
375381

@@ -419,18 +425,18 @@ std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
419425

420426
std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
421427
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
422-
IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask)
428+
IntList padding, IntList stride, IntList dilation, int64_t groups, std::array<bool,3> output_mask)
423429
{
424430
Tensor grad_output = grad_output_t.contiguous();
425431

426432
Tensor grad_input, grad_weight, grad_bias;
427433
if (output_mask[0]) {
428434
grad_input = at::mkldnn_convolution_backward_input(
429-
input.sizes(), grad_output, weight, padding, stride, dilation, output_mask[2]);
435+
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
430436
}
431437
if (output_mask[1] || output_mask[2]) {
432438
std::tie(grad_weight, grad_bias) = at::mkldnn_convolution_backward_weights(
433-
weight.sizes(), grad_output, input, padding, stride, dilation, output_mask[2]);
439+
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
434440
}
435441

436442
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -913,16 +913,16 @@
913913

914914
- func: min_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
915915

916-
- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation) -> Tensor
916+
- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation, int64_t groups) -> Tensor
917917
variants: function
918918

919-
- func: mkldnn_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, bool bias_defined) -> Tensor
919+
- func: mkldnn_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined) -> Tensor
920920
variants: function
921921

922-
- func: mkldnn_convolution_backward_weights(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, bool bias_defined) -> (Tensor, Tensor)
922+
- func: mkldnn_convolution_backward_weights(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool bias_defined) -> (Tensor, Tensor)
923923
variants: function
924924

925-
- func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
925+
- func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
926926
variants: function
927927

928928
- func: mm(Tensor self, Tensor mat2) -> Tensor

tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,8 +1209,8 @@
12091209
input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"
12101210

12111211
# mkldnn
1212-
- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation)
1213-
self, weight, bias: mkldnn_convolution_backward(self, grad, weight, padding, stride, dilation, grad_input_mask)
1212+
- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups)
1213+
self, weight, bias: mkldnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask)
12141214

12151215
# fft
12161216
- name: _fft_with_size(Tensor self, int64_t signal_ndim, bool complex_input, bool complex_output, bool inverse, IntList checked_signal_sizes, bool normalized, bool onesided, IntList output_sizes)

0 commit comments

Comments (0)