79 changes: 79 additions & 0 deletions aten/src/ATen/native/mkldnn/SoftMax.cpp
@@ -0,0 +1,79 @@
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

Tensor mkldnn_softmax(
    const Tensor& self,
    const int64_t dim,
    const bool half_to_float) {
  AT_ERROR("mkldnn_softmax: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>

namespace at {
namespace native {

namespace {
// TODO: move this to ideep
struct ideep_softmax_forward
    : public ideep::softmax_forward,
      public ideep::utils::computation_cache<ideep_softmax_forward> {
  template <typename... Ts>
  ideep_softmax_forward(
      const ideep::tensor::descriptor& src_desc,
      const ideep::tensor::descriptor& dst_desc,
      Ts&&... args) {
    init(src_desc, dst_desc, std::forward<Ts>(args)...);
  }

  template <class alloc>
  static void compute(
      const ideep::tensor& src,
      ideep::tensor& dst,
      int softmax_axis) {
    // Lazily (re)allocate the destination so its layout matches the source.
    if (dst.get_descriptor() != src.get_descriptor()) {
      dst.reinit<alloc, ideep_softmax_forward>(src.get_descriptor());
    }
    // Cache the primitive keyed on dtype, dims, internal format, and axis,
    // so same-shaped calls reuse one compiled MKL-DNN primitive.
    ideep::key_t key;
    ideep::utils::create_key(
        key,
        src.get_data_type(),
        src.get_dims(),
        src.get_internal_format(),
        softmax_axis);
    fetch_or_create_m(
        comp, key, src.get_descriptor(), dst.get_descriptor(), softmax_axis);
    comp.execute(src, dst);
  }
};
} // namespace

Tensor mkldnn_softmax(
    const Tensor& self,
    const int64_t dim,
    const bool half_to_float) {
  AT_ASSERTM(
      !half_to_float,
      "softmax with half to float conversion is not supported on Mkldnn");
  const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
  ideep::tensor& x = itensor_from_mkldnn(self);
  ideep::tensor y;
  ideep_softmax_forward::compute<AllocForMKLDNN>(x, y, wrapped_dim);
  return new_with_itensor_mkldnn(std::move(y), self.options());
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED
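
The computation_cache mixin above memoizes each constructed primitive under a key derived from the source tensor's data type, dims, internal format, and the softmax axis, so repeated softmax calls on same-shaped inputs reuse one compiled MKL-DNN primitive rather than rebuilding it. A minimal Python sketch of that fetch-or-create pattern (build_primitive is a hypothetical stand-in for the C++ primitive construction, not ideep's API):

_primitive_cache = {}

def fetch_or_create(dtype, dims, fmt, axis, build_primitive):
    # Key mirrors ideep::utils::create_key(dtype, dims, format, axis).
    key = (dtype, tuple(dims), fmt, axis)
    primitive = _primitive_cache.get(key)
    if primitive is None:
        # Cache miss: build once, then reuse for every same-shaped call.
        primitive = build_primitive(dtype, dims, fmt, axis)
        _primitive_cache[key] = primitive
    return primitive
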
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1741,6 +1741,7 @@
  dispatch:
    CPU: softmax_cpu
    CUDA: softmax_cuda
    MkldnnCPU: mkldnn_softmax

- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
  dispatch:
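
With this dispatch entry, _softmax is routed by backend: dense CPU tensors still go to softmax_cpu, while tensors converted with to_mkldnn() go to the new mkldnn_softmax. A quick way to exercise both paths from Python (assuming a build with MKL-DNN enabled):

import torch

x = torch.randn(3, 4, 5, dtype=torch.float32)

y_dense = torch.softmax(x, dim=1)               # CPU dispatch -> softmax_cpu
y_mkldnn = torch.softmax(x.to_mkldnn(), dim=1)  # MkldnnCPU dispatch -> mkldnn_softmax

# MKL-DNN results live in an opaque layout; convert back before comparing.
assert torch.allclose(y_dense, y_mkldnn.to_dense())
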
8 changes: 8 additions & 0 deletions test/test_mkldnn.py
@@ -260,6 +260,14 @@ def test_linear(self):
        self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
        self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))

    def test_softmax(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        for dim in range(x.ndim):
            softmax = torch.nn.Softmax(dim=dim)
            self.assertEqual(
                softmax(x),
                softmax(x.to_mkldnn()).to_dense())

    def test_sigmoid(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        mkldnn_x = x.to_mkldnn()
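
The test above only sweeps non-negative dims. Since the C++ path wraps dim with maybe_wrap_dim, negative dims should resolve exactly as they do for dense tensors (e.g. dim=-1 on a 3-d tensor maps to dim=2). A hedged parity check along those lines, again assuming an MKL-DNN-enabled build:

import torch
import torch.nn.functional as F

x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
for dim in (-1, -2, -3):
    y_dense = F.softmax(x, dim=dim)
    y_mkldnn = F.softmax(x.to_mkldnn(), dim=dim).to_dense()
    assert torch.allclose(y_dense, y_mkldnn)
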