
Conversation

@XiaobingSuper
Collaborator

Fixes #40391: throw an error for MKLDNN modules when mkldnn is disabled.

Collaborator

this is bringing us back to mkldnn not being called with dilation. Should this PR be landed on top of the one that allows convolutions with dilation to use mkldnn?

Collaborator Author

Enabled in #40483.

Collaborator

What is the convolution behavior w.r.t. mkldnn or dense inputs? Similar to the cases you posted above for linear and bn.

Collaborator

After #40610 torch.convolution won't be calling mkldnn in a fair number of cases, so you still need to call mkldnn_convolution here.

Collaborator Author

@XiaobingSuper Jul 1, 2020

@ngimel, I don't think we need to call mkldnn_convolution here; if the input and weight are MKLDNN tensors, it will call mkldnn_convolution directly, see

auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
#if AT_MKLDNN_ENABLED()
  if (!at::globalContext().userEnabledMkldnn()) {
    return false;
  }
  return (input.is_mkldnn()) || // input is mkldnn Tensor
    (input.options().backend() == at::Backend::CPU &&
     input.scalar_type() == kFloat && // only on CPU Float Tensors
     !transposed && // or transposed tensors
     input.ndimension() == 4); // must be in NCHW format
#endif
  return false;
}
If users set the input to an MKLDNN tensor, we assume they want to take the MKLDNN path.
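The conditions in the C++ predicate quoted above can be mirrored in plain Python to make the dispatch rules explicit. This is a hypothetical illustration only; the names and flattened parameters are not PyTorch API, and the real check is ConvParams::use_mkldnn in aten/src/ATen/native/Convolution.cpp:

```python
# Plain-Python mirror of the ConvParams::use_mkldnn predicate quoted above.
# All names and parameters here are illustrative, not PyTorch API.
def use_mkldnn(input_is_mkldnn, user_enabled_mkldnn,
               backend, scalar_type, transposed, ndim,
               build_has_mkldnn=True):
    """Return True when convolution should take the MKLDNN path."""
    if not build_has_mkldnn:        # AT_MKLDNN_ENABLED() is 0
        return False
    if not user_enabled_mkldnn:     # torch.backends.mkldnn.flags(enabled=False)
        return False
    if input_is_mkldnn:             # an MKLDNN input always takes the MKLDNN path
        return True
    # Dense inputs qualify only as non-transposed CPU float NCHW tensors.
    return (backend == "CPU" and scalar_type == "float"
            and not transposed and ndim == 4)
```

This shows the point being argued: once the input is an MKLDNN tensor and mkldnn is enabled, the MKLDNN path is taken regardless of the other conditions.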

Collaborator

Is it enough to check only for mkldnn inputs? Currently the following (which, as far as I understand, sends a regular tensor to mkldnn convolution) works:

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

N, C, M = 1, 3, 16  # illustrative sizes; originally left undefined
x = torch.randn(N, C, 56, 56, dtype=torch.float)#.to_mkldnn()
conv2d = torch.nn.Conv2d(in_channels=C,
                         out_channels=M,
                         kernel_size=3,
                         padding=1)
mkldnn_conv2d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv2d))
with torch.backends.mkldnn.flags(enabled=True):
    y_mkldnn = mkldnn_conv2d(x)#.to_dense()

Same question for all other functions - you added checks only for mkldnn inputs.

Collaborator Author

@XiaobingSuper Jun 24, 2020

@ngimel, for convolution, this input check can be removed; see:

  1. torch.backends.mkldnn.flags(enabled=True), input and weight are MKLDNN tensors: this works.
  2. torch.backends.mkldnn.flags(enabled=True), input is a dense tensor and weight is an MKLDNN tensor: the MKLDNN path is taken, and there is a check that the input and weight have the same type:
    TORCH_CHECK(input.options().type_equal(weight.options()),
      "Input type (", input.toString(), ") and weight type (", weight.toString(),
      ") should be the same");
    so this reports an error: Input type (torch.FloatTensor) and weight type (Mkldnntorch.FloatTensor) should be the same.
  3. torch.backends.mkldnn.flags(enabled=False), input and weight are MKLDNN tensors: this path is taken
    output = at::_convolution_nogroup(
      input.contiguous(), weight, bias, params.stride, params.padding,
      params.dilation, params.transposed, params.output_padding);
    which does not support MKLDNN inputs, so this reports an error: "opaque tensors do not have is_contiguous".
  4. torch.backends.mkldnn.flags(enabled=False), input is a dense tensor and weight is an MKLDNN tensor: the same _convolution_nogroup path is taken, and this reports an error: Cannot access data pointer of Tensor that doesn't have storage.
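The four convolution cases above can be summarized in a small lookup sketch. The helper below is hypothetical (not PyTorch API); the weight is assumed to always be an MKLDNN tensor, and the error strings abbreviate the ones quoted above:

```python
# Hypothetical summary of the four convolution cases listed above;
# the weight is always an MKLDNN tensor. Illustration only, not PyTorch API.
def conv_outcome(mkldnn_enabled, input_is_mkldnn):
    if mkldnn_enabled and input_is_mkldnn:
        return "ok"  # case 1: MKLDNN path runs
    if mkldnn_enabled:
        # case 2: TORCH_CHECK on matching input/weight types fires
        return "error: Input type and weight type should be the same"
    if input_is_mkldnn:
        # case 3: _convolution_nogroup calls input.contiguous() on an opaque tensor
        return "error: opaque tensors do not have is_contiguous"
    # case 4: dense input, MKLDNN weight, mkldnn disabled
    return "error: Cannot access data pointer of Tensor that doesn't have storage"
```

The argument is that all three failure cases already raise some error, so an extra input check in the convolution module is redundant, even though the messages differ.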

@zou3519 added the triaged label on Jun 23, 2020
Collaborator Author

For linear:

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

x = torch.randn(56, 56, dtype=torch.float)
linear = torch.nn.Linear(56, 56, bias=True)
mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
with torch.backends.mkldnn.flags(enabled=True):
    y_mkldnn = mkldnn_linear(x)
  1. torch.backends.mkldnn.flags(enabled=True), input and weight are MKLDNN tensors: it works.
  2. torch.backends.mkldnn.flags(enabled=True), input is a dense tensor and weight is an MKLDNN tensor: it works, because we first convert the input to an MKLDNN tensor in
    x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
  3. torch.backends.mkldnn.flags(enabled=False), input and weight are MKLDNN tensors: it reports an error: MKLDNN linear module can't run when mkldnn is disabled.
  4. torch.backends.mkldnn.flags(enabled=False), input is a dense tensor and weight is an MKLDNN tensor: it also reports the error: MKLDNN linear module can't run when mkldnn is disabled.
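The conversion that makes case 2 work can be sketched with a stand-in tensor class. This is a hypothetical illustration; the real guard is the single x_mkldnn line quoted above inside the MKLDNN linear's forward in torch/utils/mkldnn.py:

```python
# Stand-in tensor to illustrate the dense-to-MKLDNN input guard.
# StubTensor and prepare_linear_input are illustrative names, not PyTorch API.
class StubTensor:
    def __init__(self, is_mkldnn=False):
        self.is_mkldnn = is_mkldnn

    def to_mkldnn(self):
        # Conversion always yields an MKLDNN tensor.
        return StubTensor(is_mkldnn=True)

def prepare_linear_input(x):
    # Same shape as the quoted guard: dense inputs are converted first,
    # so both dense and MKLDNN inputs reach the MKLDNN kernel.
    return x if x.is_mkldnn else x.to_mkldnn()
```

This is why linear's case 2 succeeds while the BatchNorm module below, which lacks such a guard, fails on dense inputs.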

Collaborator Author

For BatchNorm:

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

x = torch.randn(1, 3, 56, 56, dtype=torch.float)
bn = torch.nn.BatchNorm2d(3).float().train(False)
mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
with torch.backends.mkldnn.flags(enabled=True):
    y_mkldnn = mkldnn_bn(x)
  1. torch.backends.mkldnn.flags(enabled=True), input and weight are MKLDNN tensors: it works.
  2. torch.backends.mkldnn.flags(enabled=True), input is a dense tensor and weight is an MKLDNN tensor: it reports an error: mkldnn_tensor.is_mkldnn() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp":56, please report a bug to PyTorch. mkldnn_to_dense expects MKL-DNN tensor input.
  3. torch.backends.mkldnn.flags(enabled=False), input and weight are MKLDNN tensors: it reports an error: MKLDNN batch_norm module can't run when mkldnn is disabled.
  4. torch.backends.mkldnn.flags(enabled=False), input is a dense tensor and weight is an MKLDNN tensor: it reports the same INTERNAL ASSERT error as case 2.
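The behavior this PR targets for the disabled cases (3 and 4) can be sketched as a guard at the top of an MKLDNN module's forward. The helper name below is hypothetical; the real modules live in torch/utils/mkldnn.py, and the error message mirrors the one quoted in case 3:

```python
# Hypothetical guard an MKLDNN module could run before dispatching;
# mirrors the "can't run when mkldnn is disabled" message quoted above.
def check_mkldnn_enabled(module_name, mkldnn_enabled):
    if not mkldnn_enabled:
        raise RuntimeError(
            "MKLDNN %s module can't run when mkldnn is disabled" % module_name)
```

Running such a check before touching the opaque tensors would turn the INTERNAL ASSERT failures in cases 2 and 4 into the intended RuntimeError.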

Collaborator

There should not be INTERNAL_ASSERTs here, right? Is this a real PyTorch bug that you have to fix, or should the error message be different? A couple of additional questions:

  1. Why is mkldnn_to_dense attempted here? Is it when converting the output of the operation to a dense tensor?
  2. Why is the behavior here different than in linear, where case 2 works and case 4 errors out as expected, with the expected error message?

Collaborator Author

For the first one, the call happens in

ideep::tensor& itensor_from_mkldnn(const MKLDNNTensor& mkldnn_tensor) {

where the error message is wrong; it should be "itensor_from_mkldnn expects MKL-DNN tensor input".

For the second: case 2 works for linear because we always convert a dense input to an MKLDNN tensor, see

x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()

Case 4's error message needs some changes so that it is unified.

@facebook-github-bot
Contributor

Hi @XiaobingSuper!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!


Labels

cla signed, open source, triaged

Development

Successfully merging this pull request may close these issues:

Mkldnn modules should throw RuntimeErrors if invoked when mkldnn is disabled