Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions aten/src/ATen/native/mkldnn/UnaryOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

Tensor mkldnn_sigmoid(const Tensor& self) {
AT_ERROR("mkldnn_sigmoid: ATen not compiled with MKLDNN support");
}

Tensor& mkldnn_sigmoid_(Tensor& self) {
AT_ERROR("mkldnn_sigmoid_: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_EBABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>

namespace at {
namespace native {

Tensor mkldnn_sigmoid(const Tensor& self) {
ideep::tensor& x = itensor_from_mkldnn(self);
ideep::tensor y;
ideep::eltwise_forward::compute(
x, y, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward);
return new_with_itensor_mkldnn(std::move(y), self.options());
}

Tensor& mkldnn_sigmoid_(Tensor& self) {
ideep::tensor& x = itensor_from_mkldnn(self);
ideep::eltwise_forward::compute(
x, x, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward);
return self;
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_EBABLED
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1590,12 +1590,17 @@

- func: sigmoid(Tensor self) -> Tensor
variants: function, method
dispatch:
CPU: sigmoid
CUDA: sigmoid
MkldnnCPU: mkldnn_sigmoid

- func: sigmoid_(Tensor(a!) self) -> Tensor(a!)
variants: function, method
dispatch:
CPU: _sigmoid__cpu
CUDA: _sigmoid__cuda
MkldnnCPU: mkldnn_sigmoid_

- func: sigmoid(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Expand Down
12 changes: 12 additions & 0 deletions test/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,18 @@ def test_linear(self):
self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))

def test_sigmoid(self):
x = torch.randn(4, 5, dtype=torch.float32) * 10
mkldnn_x = x.to_mkldnn()
self.assertEqual(
torch.sigmoid(x),
torch.sigmoid(mkldnn_x).to_dense(),
)
# inplace
torch.sigmoid_(x)
torch.sigmoid_(mkldnn_x)
self.assertEqual(x, mkldnn_x.to_dense())

def _test_serialization(self, module, inputs):
with TemporaryFileName() as fname:
torch.jit.save(module, fname)
Expand Down