Skip to content

Commit 5744fb3

Browse files
bddppq authored and facebook-github-bot committed
Add mkldnn softmax operator
Summary: Pull Request resolved: #21516 Differential Revision: D15712759 Pulled By: bddppq fbshipit-source-id: bf515135263156bea1a2b3e53a47edf697b8b1e2
1 parent a947d98 commit 5744fb3

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/Config.h>
3+
#include <ATen/NativeFunctions.h>
4+
5+
#if !AT_MKLDNN_ENABLED()
6+
7+
namespace at {
8+
namespace native {
9+
10+
// Stub for builds without MKLDNN support: always fails loudly so callers
// cannot silently fall through to a missing backend.
Tensor mkldnn_softmax(const Tensor& self, const int64_t dim, const bool half_to_float) {
  AT_ERROR("mkldnn_softmax: ATen not compiled with MKLDNN support");
}
16+
17+
} // namespace native
18+
} // namespace at
19+
20+
#else // AT_MKLDNN_ENABLED
21+
22+
#include <ATen/native/mkldnn/MKLDNNCommon.h>
23+
24+
namespace at {
25+
namespace native {
26+
27+
namespace {
28+
// TODO: move this to ideep
// Wraps ideep::softmax_forward with ideep's computation cache so the
// underlying primitive is created once per (data type, dims, format, axis)
// combination and reused on subsequent calls.
struct ideep_softmax_forward
    : public ideep::softmax_forward,
      public ideep::utils::computation_cache<ideep_softmax_forward> {
  // Forwards construction straight to the base class's init().
  template <typename... Ts>
  ideep_softmax_forward(
      const ideep::tensor::descriptor& src_desc,
      const ideep::tensor::descriptor& dst_desc,
      Ts&&... args) {
    init(src_desc, dst_desc, std::forward<Ts>(args)...);
  }

  // Computes softmax of `src` along `softmax_axis`, writing the result into
  // `dst`. `dst` is (re)allocated to match `src`'s descriptor when they
  // differ, so callers may pass a default-constructed tensor.
  template <class alloc>
  static void compute(
      const ideep::tensor& src,
      ideep::tensor& dst,
      int softmax_axis) {
    if (dst.get_descriptor() != src.get_descriptor()) {
      dst.reinit<alloc, ideep_softmax_forward>(src.get_descriptor());
    }
    // Cache key covers everything the primitive is specialized on:
    // a primitive is only reusable for an identical type/shape/format/axis.
    ideep::key_t key;
    ideep::utils::create_key(
        key,
        src.get_data_type(),
        src.get_dims(),
        src.get_internal_format(),
        softmax_axis);
    // NOTE(review): fetch_or_create_m is presumably an ideep macro from the
    // computation_cache machinery — it binds `comp` to a cached instance for
    // `key`, constructing a new ideep_softmax_forward on a miss; confirm
    // against the ideep headers.
    fetch_or_create_m(
        comp, key, src.get_descriptor(), dst.get_descriptor(), softmax_axis);
    comp.execute(src, dst);
  }
};
60+
} // namespace
61+
62+
// Softmax for MKLDNN-layout tensors. `dim` may be negative and is wrapped
// into [0, self.dim()). Half-to-float up-conversion is not supported on
// this backend and is rejected up front.
Tensor mkldnn_softmax(
    const Tensor& self,
    const int64_t dim,
    const bool half_to_float) {
  AT_ASSERTM(
      !half_to_float,
      "softmax with half to float conversion is not supported on Mkldnn");
  const int64_t axis = maybe_wrap_dim(dim, self.dim());
  // Borrow the underlying ideep tensor; the result is materialized into a
  // fresh ideep tensor and re-wrapped as an ATen MKLDNN tensor.
  ideep::tensor& src = itensor_from_mkldnn(self);
  ideep::tensor dst;
  ideep_softmax_forward::compute<AllocForMKLDNN>(src, dst, axis);
  return new_with_itensor_mkldnn(std::move(dst), self.options());
}
75+
76+
} // namespace native
77+
} // namespace at
78+
79+
#endif // AT_MKLDNN_ENABLED

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,7 @@
17411741
dispatch:
17421742
CPU: softmax_cpu
17431743
CUDA: softmax_cuda
1744+
MkldnnCPU: mkldnn_softmax
17441745

17451746
- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
17461747
dispatch:

test/test_mkldnn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,14 @@ def test_linear(self):
260260
self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
261261
self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))
262262

263+
def test_softmax(self):
    # The MKLDNN softmax (converted back to dense) must match the dense
    # reference implementation along every axis.
    inp = torch.randn(3, 4, 5, dtype=torch.float32) * 10
    for axis in range(inp.ndim):
        module = torch.nn.Softmax(dim=axis)
        expected = module(inp)
        actual = module(inp.to_mkldnn()).to_dense()
        self.assertEqual(expected, actual)
270+
263271
def test_sigmoid(self):
264272
x = torch.randn(4, 5, dtype=torch.float32) * 10
265273
mkldnn_x = x.to_mkldnn()

0 commit comments

Comments
 (0)