Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aten/src/ATen/native/LegacyDefinitions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include <ATen/NativeFunctions.h>
#include <ATen/LegacyTHFunctions.h>

#include <ATen/native/mkldnn/TensorShape.h>

namespace at { namespace native {

// Methods
Expand Down Expand Up @@ -63,6 +65,9 @@ Tensor & masked_scatter_(Tensor& self, const Tensor & mask, const Tensor & sourc
}

Tensor view(const Tensor& self, IntArrayRef size) {
if (self.is_mkldnn()) {
return mkldnn_view(self, size);
}
return at::legacy::th::_th_view(self, size);
}

Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
AT_ERROR("reshape is not implemented for sparse tensors");
}
auto shape = infer_size(proposed_shape, self.numel());

if (self.is_mkldnn()) {
return at::mkldnn_reshape(self, shape);
}

if (auto stride = THTensor_compute_stride(self.sizes(), self.strides(), shape)) {
return self.as_strided(shape, *stride);
}
Expand Down
46 changes: 46 additions & 0 deletions aten/src/ATen/native/mkldnn/TensorShape.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

Tensor mkldnn_view(const Tensor& self, IntArrayRef size) {
AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
}

Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_EBABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>

namespace at {
namespace native {

Tensor mkldnn_view(const Tensor& self, IntArrayRef size) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I guess you can move it out of the #ifdef as the error is regardless of whether mkldnn is compiled in

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error messages are different though so I would prefer to keep two these two places.

AT_ERROR(
"Currently Mkldnn tensor does not support view. Change to use reshape instead");
}

Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
auto inferred_size = at::infer_size(size, self.numel());
const ideep::tensor& x = itensor_from_mkldnn(self);
ideep::tensor y;
ideep::direct_copy::compute<AllocForMKLDNN>(x, y);
y.reshape({inferred_size.cbegin(), inferred_size.cend()});
return new_with_itensor_mkldnn(std::move(y), self.options());
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_EBABLED
11 changes: 11 additions & 0 deletions aten/src/ATen/native/mkldnn/TensorShape.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <ATen/ATen.h>

namespace at {
namespace native {

Tensor mkldnn_view(const Tensor& self, IntArrayRef size);

} // namespace native
} // namespace at
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1522,6 +1522,12 @@
variants: function, method
device_guard: False

- func: mkldnn_reshape(Tensor self, int[] shape) -> Tensor
device_guard: False
requires_tensor: True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need requires_tensor if you have dispatch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep I know this, just for documentation purpose :-)

dispatch:
MkldnnCPU: mkldnn_reshape

- func: reshape_as(Tensor self, Tensor other) -> Tensor
variants: method
device_guard: False
Expand Down
15 changes: 15 additions & 0 deletions test/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,21 @@ def test_add(self):
torch.add(mx, my, alpha=alpha, out=mkldnn_out)
self.assertEqual(out, mkldnn_out.to_dense())

def test_view(self):
x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
self.assertRaisesRegex(RuntimeError,
"Change to use reshape",
lambda: x.view(x.size(0), -1))

def test_reshape(self):
x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
size = (x.size(0), -1)

self.assertEqual(
x.reshape(size),
x.to_mkldnn().reshape(size).to_dense(),
)


if __name__ == '__main__':
run_tests()