#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/NativeFunctions.h>

#if !AT_MKLDNN_ENABLED()
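// Stub path used when ATen is built without MKL-DNN support:
// calling the operator simply reports that it is unavailable.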
| 6 | + |
| 7 | +namespace at { |
| 8 | +namespace native { |
| 9 | + |
| 10 | +Tensor mkldnn_softmax( |
| 11 | + const Tensor& self, |
| 12 | + const int64_t dim, |
| 13 | + const bool half_to_float) { |
| 14 | + AT_ERROR("mkldnn_softmax: ATen not compiled with MKLDNN support"); |
| 15 | +} |
| 16 | + |
| 17 | +} // namespace native |
| 18 | +} // namespace at |
| 19 | + |
#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>

namespace at {
namespace native {

namespace {
// TODO: move this to ideep
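// Combines ideep::softmax_forward with ideep's computation cache so that
// the underlying MKL-DNN primitive is reused across calls that share the
// same source descriptor and softmax axis.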
struct ideep_softmax_forward
    : public ideep::softmax_forward,
      public ideep::utils::computation_cache<ideep_softmax_forward> {
  template <typename... Ts>
  ideep_softmax_forward(
      const ideep::tensor::descriptor& src_desc,
      const ideep::tensor::descriptor& dst_desc,
      Ts&&... args) {
    init(src_desc, dst_desc, std::forward<Ts>(args)...);
  }

  template <class alloc>
  static void compute(
      const ideep::tensor& src,
      ideep::tensor& dst,
      int softmax_axis) {
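    // Reallocate dst so it matches src's descriptor (dims, data type,
    // format) whenever the two disagree.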
    if (dst.get_descriptor() != src.get_descriptor()) {
      dst.reinit<alloc, ideep_softmax_forward>(src.get_descriptor());
    }
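    // Build a cache key from the source's data type, dims, internal format
    // and the softmax axis so an equivalent primitive can be looked up.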
    ideep::key_t key;
    ideep::utils::create_key(
        key,
        src.get_data_type(),
        src.get_dims(),
        src.get_internal_format(),
        softmax_axis);
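    // Reuse a cached primitive for this key (or create one) and execute it.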
    fetch_or_create_m(
        comp, key, src.get_descriptor(), dst.get_descriptor(), softmax_axis);
    comp.execute(src, dst);
  }
};
} // namespace

Tensor mkldnn_softmax(
    const Tensor& self,
    const int64_t dim,
    const bool half_to_float) {
  AT_ASSERTM(
      !half_to_float,
      "softmax with half to float conversion is not supported on Mkldnn");
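  // Wrap a possibly negative dim (e.g. -1) into the valid [0, self.dim()) range.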
  const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
  ideep::tensor& x = itensor_from_mkldnn(self);
  ideep::tensor y;
  ideep_softmax_forward::compute<AllocForMKLDNN>(x, y, wrapped_dim);
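  // Wrap the ideep result back into an opaque MKL-DNN Tensor carrying
  // self's options.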
  return new_with_itensor_mkldnn(std::move(y), self.options());
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED