Skip to content

Commit b599bb3

Browse files
XiaobingSuper authored and facebook-github-bot committed
Add mkldnn mul operator (#20575)
Summary: ### mkldnn backward ops list: - [ ] \(#20567) Add aten mkldnn conv2d backward operator 💛 - [ ] \(#20570) Add aten mkldnn backward ops: relu, linear and reshape 💛 - [ ] \(#20571) Add aten mkldnn backward ops: max_pool2d, avg_pool2d and adaptive_avg_poo2d 💛 - [ ] \(#20572) Add aten mkldnn batchnorm backward operator 💛 - [ ] \(#20573) Add aten mkldnn zero_ operator 💛 - [ ] \(#20575) Add mkldnn mul operator 💛 Pull Request resolved: #20575 Differential Revision: D15799529 Pulled By: bddppq fbshipit-source-id: 4887d8ef1a0e316ad9db199b657d9481fc13e486
1 parent d3b3cbe commit b599bb3

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

aten/src/ATen/native/mkldnn/BinaryOps.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) {
2323
AT_ERROR("mkldnn_add_: ATen not compiled with MKLDNN support");
2424
}
2525

26+
Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
27+
AT_ERROR("mkldnn_mul_out: ATen not compiled with MKLDNN support");
28+
}
29+
30+
Tensor mkldnn_mul(const Tensor& self, const Tensor& other) {
31+
AT_ERROR("mkldnn_mul: ATen not compiled with MKLDNN support");
32+
}
33+
34+
Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) {
35+
AT_ERROR("mkldnn_mul_: ATen not compiled with MKLDNN support");
36+
}
37+
2638
} // namespace native
2739
} // namespace at
2840

@@ -63,6 +75,38 @@ Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) {
6375
return native::mkldnn_add_out(self, self, other, alpha);
6476
}
6577

78+
Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
79+
AT_ASSERTM(result.sizes() == self.sizes(),
80+
"mkldnn_mul_out: the output size should be same as input size");
81+
ideep::tensor& z = itensor_from_mkldnn(result);
82+
ideep::tensor& x = itensor_from_mkldnn(self);
83+
84+
// for zero_dim tensor
85+
if (other.ndimension() == 0) {
86+
ideep::eltwise_forward::compute<AllocForMKLDNN>(
87+
x, z, ideep::algorithm::eltwise_linear,
88+
ideep::prop_kind::forward_inference, /*alpha*/ other.item().to<float>());
89+
90+
return result;
91+
} else {
92+
AT_ASSERTM(self.sizes() == other.sizes(),
93+
"mkldnn_mul_out: currently mkldnn not support broadcasting");
94+
ideep::tensor y = itensor_from_mkldnn(other);
95+
auto op = ideep::eltwise_binary::eltwise_binary_op::ELTWISE_MUL;
96+
ideep::eltwise_binary::compute<AllocForMKLDNN>(op, x, y, z);
97+
98+
return result;
99+
}
100+
}
101+
102+
// Out-of-place elementwise multiply: allocates a fresh MKL-DNN tensor
// shaped like `self` and delegates to mkldnn_mul_out.
Tensor mkldnn_mul(const Tensor& self, const Tensor& other) {
  Tensor output = empty_mkldnn(self.sizes(), self.options());
  return native::mkldnn_mul_out(output, self, other);
}
106+
107+
// In-place elementwise multiply: reuses `self` as the output buffer.
Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) {
  return native::mkldnn_mul_out(self, self, other);
}
66110

67111
} // namespace native
68112
} // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,11 +1303,30 @@
13031303

13041304
- func: mul(Tensor self, Tensor other) -> Tensor
13051305
variants: function, method
1306+
dispatch:
1307+
CPU: mul
1308+
CUDA: mul
1309+
SparseCPU: mul
1310+
SparseCUDA: mul
1311+
MkldnnCPU: mkldnn_mul
1312+
13061313

13071314
- func: mul_(Tensor(a!) self, Tensor other) -> Tensor(a!)
13081315
variants: method
1316+
dispatch:
1317+
CPU: mul_
1318+
CUDA: mul_
1319+
SparseCPU: mul_
1320+
SparseCUDA: mul_
1321+
MkldnnCPU: mkldnn_mul_
13091322

13101323
- func: mul(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
1324+
dispatch:
1325+
CPU: mul_out
1326+
CUDA: mul_out
1327+
SparseCPU: mul_out
1328+
SparseCUDA: mul_out
1329+
MkldnnCPU: mkldnn_mul_out
13111330

13121331
# For C++ only, until we have conversion from C++ numbers to Tensor
13131332
- func: mul(Tensor self, Scalar other) -> Tensor

test/test_mkldnn.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,55 @@ def test_add(self):
216216
torch.add(mx, my, alpha=alpha, out=mkldnn_out)
217217
self.assertEqual(out, mkldnn_out.to_dense())
218218

219+
def test_mul(self):
    # Elementwise multiplication must agree between dense CPU tensors and
    # their MKL-DNN counterparts, for tensor*tensor and tensor*scalar,
    # across the out-of-place, in-place, and out= variants.
    batch = torch.randint(3, 10, (1,)).item()
    channels = torch.randint(3, 100, (1,)).item()
    scalar = torch.randn(1, dtype=torch.float32).item()

    x = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
    y = torch.randn(batch, channels, 35, 45, dtype=torch.float32) * 10
    x_mkldnn = x.to_mkldnn()
    y_mkldnn = y.to_mkldnn()

    # mul (out-of-place): operator form and torch.mul form
    self.assertEqual(x * y, (x_mkldnn * y_mkldnn).to_dense())
    self.assertEqual(x * scalar, (x_mkldnn * scalar).to_dense())
    self.assertEqual(torch.mul(x, y), torch.mul(x_mkldnn, y_mkldnn).to_dense())
    self.assertEqual(
        torch.mul(x, scalar), torch.mul(x_mkldnn, scalar).to_dense())

    # mul_ (in-place)
    x *= y
    x_mkldnn *= y_mkldnn
    self.assertEqual(x, x_mkldnn.to_dense())

    x *= scalar
    x_mkldnn *= scalar
    self.assertEqual(x, x_mkldnn.to_dense())

    # mul with an explicit out= tensor
    out = x.clone()
    mkldnn_out = out.to_mkldnn()
    torch.mul(x, y, out=out)
    torch.mul(x_mkldnn, y_mkldnn, out=mkldnn_out)
    self.assertEqual(out, mkldnn_out.to_dense())

    out = x.clone()
    mkldnn_out = out.to_mkldnn()
    torch.mul(x, scalar, out=out)
    torch.mul(x_mkldnn, scalar, out=mkldnn_out)
    self.assertEqual(out, mkldnn_out.to_dense())
267+
219268
def test_view(self):
220269
x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
221270
self.assertRaisesRegex(RuntimeError,

0 commit comments

Comments
 (0)