4 changes: 2 additions & 2 deletions aten/src/ATen/native/Convolution.cpp
@@ -235,8 +235,8 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
return (input.is_mkldnn()) || // input is mkldnn Tensor
(input.options().backend() == at::Backend::CPU &&
input.scalar_type() == kFloat && // only on CPU Float Tensors
!transposed && // or transposed tensors
input.ndimension() == 4); // must be in NCHW format
!transposed // or transposed tensors
);
#endif
return false;
}
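Illustrative sketch (not part of this diff): with the input.ndimension() == 4 check dropped, use_mkldnn no longer excludes 5-d NCDHW inputs, so a dense CPU float Conv3d becomes a candidate for the MKL-DNN path. A minimal Python illustration, assuming an MKL-DNN-enabled build; shapes are arbitrary:

import torch

x = torch.randn(2, 3, 8, 8, 8, dtype=torch.float32)  # 5-d NCDHW input
conv3d = torch.nn.Conv3d(3, 6, kernel_size=3, padding=1)
with torch.backends.mkldnn.flags(enabled=True):
    y = conv3d(x)  # eligible for the MKL-DNN convolution path after this change

Whether the MKL-DNN kernel is actually selected still depends on the remaining checks above (CPU backend, float dtype, not transposed) and on the build having AT_MKLDNN_ENABLED.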
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mkldnn/Conv.cpp
@@ -55,7 +55,7 @@ inline ideep::tensor get_mkldnn_tensor(const at::Tensor& tensor) {

namespace at { namespace native {

ideep::tensor _mkldnn_conv2d(
ideep::tensor _mkldnn_convolution(
const ideep::tensor& x,
const ideep::tensor& w,
const c10::optional<ideep::tensor>& b,
@@ -113,7 +113,7 @@ at::Tensor mkldnn_convolution(
mkldnn_bias = get_mkldnn_tensor(bias);
}

ideep::tensor mkldnn_output = _mkldnn_conv2d(
ideep::tensor mkldnn_output = _mkldnn_convolution(
mkldnn_input,
mkldnn_weight,
mkldnn_bias,
10 changes: 5 additions & 5 deletions aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp
@@ -52,21 +52,21 @@ Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options)
}

ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {
AT_ASSERTM(mkldnn_tensor.is_mkldnn(),
"mkldnn_to_dense expects MKL-DNN tensor input");
TORCH_CHECK(mkldnn_tensor.is_mkldnn(),
"itensor_from_mkldnn expects MKL-DNN tensor input");
TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
MKLDNNTensorImpl *mklimpl = static_cast<MKLDNNTensorImpl *>(mkldnn_tensor.unsafeGetTensorImpl());
return mklimpl->unsafe_opaque_handle()->get_target();
}

ideep::tensor itensor_view_from_dense(const Tensor& tensor) {
AT_ASSERTM(
TORCH_CHECK(
tensor.device().type() == DeviceType::CPU,
"itensor_view_from_dense expects CPU tensor input");
AT_ASSERTM(
TORCH_CHECK(
tensor.layout() == Layout::Strided,
"itensor_view_from_dense expects dense tensor input");
AT_ASSERTM(tensor.scalar_type() == ScalarType::Float,
TORCH_CHECK(tensor.scalar_type() == ScalarType::Float,
"itensor_view_from_dense expects float tensor input");
TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
return {{{tensor.sizes().cbegin(), tensor.sizes().cend()},
61 changes: 46 additions & 15 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -22,13 +22,13 @@ Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
}

Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
AT_ASSERTM(cpu_tensor.device().type() == DeviceType::CPU,
TORCH_CHECK(cpu_tensor.device().type() == DeviceType::CPU,
"dense_to_mkldnn expects CPU tensor input");
AT_ASSERTM(cpu_tensor.layout() == Layout::Strided,
TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
"dense_to_mkldnn expects strided tensor input");
AT_ASSERTM(cpu_tensor.scalar_type() == ScalarType::Float,
TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float,
"dense_to_mkldnn expects float tensor input");
AT_ASSERTM(cpu_tensor.dim() <= 5,
TORCH_CHECK(cpu_tensor.dim() <= 5,
"Can't convert cpu tensor with the number of dimensions > 5");
// TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly.
auto cpu_tensor_cont = cpu_tensor.contiguous();
@@ -53,10 +53,6 @@ Tensor mkldnn_reorder_conv2d_weight(
IntArrayRef dilation,
int64_t groups) {

auto stride_vec = expand_param_if_needed(stride, "stride", 2);
auto padding_vec = expand_param_if_needed(padding, "padding", 2);
auto dilation_vec = expand_param_if_needed(dilation, "dilation", 2);

auto w = itensor_from_mkldnn(self);

// Legacy mkldnn conv2d jitted module may contain a 5-d weight with an extra
@@ -73,10 +69,36 @@
ideep::convolution_forward::expected_weights_desc(
w.get_dims(),
w.get_data_type(),
{stride_vec.cbegin(), stride_vec.cend()},
{padding_vec.cbegin(), padding_vec.cend()},
{padding_vec.cbegin(), padding_vec.cend()},
{dilation_vec.cbegin(), dilation_vec.cend()},
{stride.begin(), stride.end()},
{padding.begin(), padding.end()},
{padding.begin(), padding.end()},
{dilation.begin(), dilation.end()},
groups,
ideep::algorithm::convolution_direct);
ideep::tensor result;
result.init(desc);
result.feed_from(w);

return new_with_itensor_mkldnn(std::move(result), self.options());
}

Tensor mkldnn_reorder_conv3d_weight(
const Tensor& self,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {

auto w = itensor_from_mkldnn(self);

auto desc =
ideep::convolution_forward::expected_weights_desc(
w.get_dims(),
w.get_data_type(),
{stride.begin(), stride.end()},
{padding.begin(), padding.end()},
{padding.begin(), padding.end()},
{dilation.begin(), dilation.end()},
groups,
ideep::algorithm::convolution_direct);
ideep::tensor result;
@@ -89,11 +111,11 @@ Tensor mkldnn_reorder_conv2d_weight(
#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
AT_ERROR("MKL-DNN build is disabled");
TORCH_CHECK(false, "MKL-DNN build is disabled");
}

Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
AT_ERROR("MKL-DNN build is disabled");
TORCH_CHECK(false, "MKL-DNN build is disabled");
}

Tensor mkldnn_reorder_conv2d_weight(
@@ -102,7 +124,16 @@ Tensor mkldnn_reorder_conv2d_weight(
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
AT_ERROR("mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
}

Tensor mkldnn_reorder_conv3d_weight(
const Tensor& self,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups) {
TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}

#endif // AT_MKLDNN_ENABLED()
7 changes: 7 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3508,6 +3508,13 @@
  dispatch:
    MkldnnCPU: mkldnn_reorder_conv2d_weight

- func: mkldnn_reorder_conv3d_weight(Tensor self, int[3] padding=0, int[3] stride=1, int[3] dilation=1, int groups=1) -> Tensor
  use_c10_dispatcher: full
  variants: function
  python_module: nn
  dispatch:
    MkldnnCPU: mkldnn_reorder_conv3d_weight

- func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor
  use_c10_dispatcher: full

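Hypothetical call sketch (not part of this diff): like mkldnn_reorder_conv2d_weight, the new entry dispatches on MkldnnCPU and is exposed under torch._C._nn (python_module: nn), taking an MKL-DNN weight tensor plus padding, stride, dilation and groups. Values below are illustrative:

import torch

w = torch.randn(8, 4, 3, 3, 3, dtype=torch.float32)  # OIDHW conv3d weight
w_blocked = torch._C._nn.mkldnn_reorder_conv3d_weight(
    w.to_mkldnn(),  # must already be an MKL-DNN tensor
    [1, 1, 1],      # padding
    [2, 2, 2],      # stride
    [1, 1, 1],      # dilation
    1)              # groups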
23 changes: 23 additions & 0 deletions test/test_mkldnn.py
@@ -180,6 +180,29 @@ def test_conv2d_legacy_jit_model(self):
conv2d(x),
conv2d_loaded(x.to_mkldnn()).to_dense())

def test_conv3d(self):
    for groups in [1, 4]:
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(1, 3, (1,)).item() * groups
        M = torch.randint(1, 3, (1,)).item() * groups
        x = torch.randn(N, C, 55, 55, 55, dtype=torch.float32)
        for bias in [True, False]:
            conv3d = torch.nn.Conv3d(in_channels=C,
                                     out_channels=M,
                                     kernel_size=3,
                                     stride=2,
                                     padding=1,
                                     bias=bias,
                                     groups=groups).float()
            mkldnn_conv3d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv3d))
            with torch.backends.mkldnn.flags(enabled=False):
                y_aten = conv3d(x)
                y_mkldnn = mkldnn_conv3d(x.to_mkldnn()).to_dense()
                self.assertEqual(y_aten, y_mkldnn)

            self._test_serialization(mkldnn_conv3d, (x.to_mkldnn(),))
            self._test_tracing(mkldnn_conv3d, (x.to_mkldnn(),))

def test_relu(self):
x = torch.randn((4, 5), dtype=torch.float32) * 10
self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
4 changes: 2 additions & 2 deletions test/test_nn.py
@@ -4106,7 +4106,7 @@ def test_Conv3d_groups_nobias(self):
atol=dtype2prec_DONTUSE[torch.float], rtol=0)
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
atol=dtype2prec_DONTUSE[torch.float], rtol=0)
atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float])

def test_Conv3d_groups_wbias(self):
torch.manual_seed(123)
@@ -4139,7 +4139,7 @@ def test_Conv3d_groups_wbias(self):
atol=dtype2prec_DONTUSE[torch.float], rtol=0)
self.assertEqual(m.bias.grad.data,
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
atol=dtype2prec_DONTUSE[torch.float], rtol=0)
atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float])

# Very similar to test_Conv2d_naive_groups but with special care to handle
# the number of groups == number of input channels
24 changes: 24 additions & 0 deletions torch/utils/mkldnn.py
@@ -104,6 +104,28 @@ def __setstate__(self, state):
self.bias = state[1].to_mkldnn()
self.training = state[2]

class MkldnnConv3d(_MkldnnConvNd):
    def __init__(self, dense_module):
        super(MkldnnConv3d, self).__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnBatchNorm2d(torch.jit.ScriptModule):
__constants__ = ['exponential_average_factor', 'eps']
@@ -165,6 +187,8 @@ def m_fn(m):
        return MkldnnConv1d(m)
    elif isinstance(m, torch.nn.Conv2d):
        return MkldnnConv2d(m)
    elif isinstance(m, torch.nn.Conv3d):
        return MkldnnConv3d(m)
    elif isinstance(m, torch.nn.BatchNorm2d):
        return MkldnnBatchNorm2d(m)
    else:
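End-to-end usage sketch (not part of this diff), mirroring the new test_conv3d: convert a float Conv3d module for inference so its weight is pre-reordered via mkldnn_reorder_conv3d_weight, then run it on MKL-DNN tensors. Shapes are illustrative:

import torch
from torch.utils import mkldnn as mkldnn_utils

conv3d = torch.nn.Conv3d(3, 8, kernel_size=3, stride=2, padding=1).eval()
mkldnn_conv3d = mkldnn_utils.to_mkldnn(conv3d)  # now maps Conv3d to MkldnnConv3d

x = torch.randn(1, 3, 16, 16, 16, dtype=torch.float32)
y = mkldnn_conv3d(x.to_mkldnn()).to_dense()  # forward on MKL-DNN tensors, convert back to dense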