
Conversation

@XiaobingSuper (Collaborator) commented Mar 30, 2020

Stack from ghstack:

Differential Revision: D22102408

[ghstack-poisoned]
@dr-ci (bot) commented Mar 30, 2020

💊 CI failures summary and remediations

As of commit 91a0981 (more details on the Dr. CI page):



❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_bazel_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun) ❄️

Jun 19 02:23:59 TIMEOUT: //:integration_test (Summary)
Jun 19 02:23:59              for (int k = 0; k < cross_chunk_shuffle_count; ++k) { 
Jun 19 02:23:59                              ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~ 
Jun 19 02:23:59 test/cpp/api/dataloader.cpp:2204:13: warning: unused variable 'offset' [-Wunused-variable] 
Jun 19 02:23:59          int offset = 0; 
Jun 19 02:23:59              ^~~~~~ 
Jun 19 02:23:59 test/cpp/api/dataloader.cpp: In member function 'virtual void DataLoaderTest_CustomPreprocessPolicy_Test::TestBody()': 
Jun 19 02:23:59 test/cpp/api/dataloader.cpp:2294:29: warning: comparison between signed and unsigned integer expressions [-Wsign-compare] 
Jun 19 02:23:59            for (int i = 0; i < batch_result.size(); i += chunk_size) { 
Jun 19 02:23:59                            ~~^~~~~~~~~~~~~~~~~~~~~ 
Jun 19 02:23:59  
Jun 19 02:23:59 TIMEOUT: //:integration_test (Summary) 
Jun 19 02:23:59       /var/lib/jenkins/.cache/bazel/_bazel_jenkins/fdf6d09bf4b4f04a71e2a7dfceb40620/execroot/pytorch/bazel-out/k8-fastbuild/testlogs/integration_test/test.log 
Jun 19 02:23:59 INFO: From Testing //:integration_test: 
Jun 19 02:23:59 ==================== Test output for //:integration_test: 
Jun 19 02:23:59 Running main() from gmock_main.cc 
Jun 19 02:23:59 Note: Google Test filter = -*CUDA 
Jun 19 02:23:59 [==========] Running 1 test from 1 test suite. 
Jun 19 02:23:59 [----------] Global test environment set-up. 
Jun 19 02:23:59 [----------] 1 test from IntegrationTest 
Jun 19 02:23:59 [ RUN      ] IntegrationTest.CartPole 
Jun 19 02:23:59 -- Test timed out at 2020-06-19 02:23:44 UTC -- 

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@XiaobingSuper (Collaborator, Author) commented Mar 30, 2020

@ngimel, @VitalyFedyunin, this PR enables the DNNL 3d ops, including conv, pooling, and batchnorm. For resnext3d-101, tested on the real dataset UCF101 (input size 10x3x32x128x170), we get a ~13x performance improvement compared with the native CPU path on SKX-8180. You can see the details in resnext3d-101. Thanks!

@vincentqb (Contributor)

@VitalyFedyunin -- could you review this PR?

@vincentqb vincentqb added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Mar 30, 2020
@XiaobingSuper XiaobingSuper requested a review from ngimel March 31, 2020 07:12
@XiaobingSuper (Collaborator, Author)

@VitalyFedyunin, please help review this code, thanks!

@yinghai (Contributor) commented Apr 1, 2020

@lly-zero-one Could you comment on the perf side?

@lly-zero-one (Contributor)

We have a few internal changes to the current Conv3d implementation for performance improvement, which will be upstreamed next week. So I am wondering whether we could do a full performance benchmark. For the 2d case, we found the mkldnn conv is 2x slower than the native implementation on a specific production model (I will file a repro).

@mingfeima (Collaborator)

> We have a few internal changes to the current Conv3d implementation for performance improvement, which will be upstreamed next week. So I am wondering whether we could do a full performance benchmark. For the 2d case, we found the mkldnn conv is 2x slower than the native implementation on a specific production model (I will file a repro).

A 2x performance diff with the native implementation is serious... In the future, if you have similar issues, you can also raise them in the Teams channel; we will get hands on it asap.
@Jianhui-Li, @uyongw, @jgong5

@lly-zero-one (Contributor)

#35937 is for tracking the issue.

@VitalyFedyunin (Contributor) left a comment

This is inconsistent with the approach we use for operator naming: we always explicitly specify 1d, 2d, and 3d operators and let the Python nn module dispatch to the proper one.

Using this approach, you will not only follow the convention but also avoid introducing backward-incompatible changes.

This comment applies to all PRs in the stack.
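
For readers, here is a minimal sketch (not code from this PR) of the convention being asked for: dimension-specific operators, with the choice made on the Python side, using the plain torch.conv2d/torch.conv3d ops as stand-ins for the mkldnn-specific ones.

```python
# Sketch only: explicitly named 2d/3d operators, selected by the Python layer
# based on input dimensionality, instead of one dimension-agnostic entry point.
import torch

def conv_dispatch(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if x.dim() == 4:   # N, C, H, W -> 2d operator
        return torch.conv2d(x, weight, bias, stride, padding, dilation, groups)
    if x.dim() == 5:   # N, C, D, H, W -> 3d operator
        return torch.conv3d(x, weight, bias, stride, padding, dilation, groups)
    raise ValueError("unsupported input dimension: {}".format(x.dim()))
```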

@XiaobingSuper (Collaborator, Author)

@VitalyFedyunin

@XiaobingSuper XiaobingSuper requested a review from albanD May 14, 2020 01:15
@XiaobingSuper (Collaborator, Author)

@pinzhenx pinzhenx mentioned this pull request Jun 16, 2020
facebook-github-bot pushed a commit that referenced this pull request Jun 16, 2020
Summary:
- Bump DNNL to 1.5
- Bug fixes and improvements in ideep
  - suppress g++ Wreorder warning
  - avoid rebuilding `libmkldnn.so` uxlfoundation/oneDNN#743
  - enable conv3d (integration code was checked in by Xiaobing #35662)
Pull Request resolved: #40088

Differential Revision: D22071530

Pulled By: albanD

fbshipit-source-id: e7a53d7421e8a7a03e36a7dfb68edc565a2f00df
@XiaobingSuper (Collaborator, Author)

@ngimel, please help merge those PRs, thanks!

        self.bias = state[1].to_mkldnn()
        self.training = state[2]

class MkldnnConv3d(torch.jit.ScriptModule):
A Member commented:

FYI I believe this is kind of a legacy API as we now compile nn Modules recursively. Cc @eellison

A Contributor commented:

Yeah, correct. It's better to inherit from torch.nn.Module; you shouldn't need any other changes.
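
For reference, a tiny example (using the public API, not code from this PR) of the recursive-scripting approach being suggested, where a plain torch.nn.Module is compiled with torch.jit.script instead of subclassing ScriptModule:

```python
# Sketch of the suggested approach: no ScriptModule subclass needed, because
# torch.jit.script compiles nn.Module instances (and their submodules) recursively.
import torch
import torch.nn as nn

class TinyConv3d(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

scripted = torch.jit.script(TinyConv3d())
print(scripted(torch.randn(1, 3, 4, 8, 8)).shape)  # torch.Size([1, 8, 4, 8, 8])
```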

@XiaobingSuper (Collaborator, Author) replied:

Thanks, I will change it, and also the other cases in a next step.

@XiaobingSuper (Collaborator, Author) replied:

@eellison, if it inherits from torch.nn.Module, there will be a problem with the torch.jit.save method: for an MKLDNN module, the parameters are MKLDNN tensors, which are opaque tensors (they do not have storage), so we first call .to_dense() in __getstate__ to save this script module. I will change it once this problem can be solved. Thanks!
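
To make the constraint concrete, here is a rough sketch (simplified from the pattern used in this PR stack; the weight-reorder call and constructor details are omitted) of how the script module round-trips its parameters through dense tensors for serialization:

```python
# Because MKLDNN tensors are opaque (no storage), __getstate__ converts the
# parameters back to dense tensors so torch.jit.save can serialize the module,
# and __setstate__ converts them back to the MKLDNN layout on load.
import torch

class MkldnnConv3dSketch(torch.jit.ScriptModule):
    def __init__(self, dense_module):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]
```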


    @torch.jit.script_method
    def forward(self, x):
        return torch.conv3d(
A Collaborator commented:

What would happen here if the parameters are not supported by mkldnn (use_mkldnn would return false, e.g. because of dilation, or because x is the wrong type), but the weight is already reordered?
Also, suppose mkldnn_convolution is indeed called from Convolution.cpp; what happens next? In mkldnn_convolution there's only

  ideep::tensor mkldnn_output = _mkldnn_conv2d(
      mkldnn_input,
      mkldnn_weight,
      mkldnn_bias,
      padding,
      stride,
      dilation,
      groups);

If it is able to handle conv3d, then at the very least it is confusingly named.

@XiaobingSuper (Collaborator, Author) replied:

For the first question, DNNL also supports dilation for convNd, so I can enable it. Then there is only one case not supported by DNNL: x is not a float tensor. But in that case the weight should also be a non-float tensor, and an error will be reported to the user when reordering the weight, because the reorder needs to call .to_mkldnn() first, which checks the tensor's type.

For the second question: yes, the name is confusing, I will change it.
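
For illustration, a small check of the type-check path mentioned above (the exact error message is not guaranteed; only the fact that non-float tensors are rejected is relied on here):

```python
# The reorder path calls .to_mkldnn() on the weight, and to_mkldnn() only accepts
# float tensors, so a double weight fails up front with a user-visible error.
import torch

w_float = torch.randn(8, 3, 3, 3, 3)                        # float32 conv3d weight: OK
w_mkldnn = w_float.to_mkldnn()

w_double = torch.randn(8, 3, 3, 3, 3, dtype=torch.double)   # wrong dtype
try:
    w_double.to_mkldnn()
except RuntimeError as err:
    print("to_mkldnn rejected the non-float tensor:", err)
```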

A Collaborator replied:

Ok, please enable dilated convolution then. I agree that for correct user inputs this situation should not happen, but even in the case of incorrect user inputs (the user had a float weight when creating the module but is sending a double tensor, or a CUDA tensor, to forward), the error message should be clear and helpful, and I don't know what will happen in this case.

@XiaobingSuper (Collaborator, Author) replied:

The DNNL dilated conv is enabled now. Thanks!

  return (input.is_mkldnn()) || // input is mkldnn Tensor
         (input.options().backend() == at::Backend::CPU &&
          input.scalar_type() == kFloat && // only on CPU Float Tensors
          !is_dilated() && // doesn't support dilation
A Collaborator commented:

Should you also remove the is_dilated check from here, if dilation is actually supported by mkldnn?

@XiaobingSuper (Collaborator, Author) replied:

There is another PR to do that; see #40220.

    IntArrayRef dilation,
    int64_t groups) {

  auto stride_vec = expand_param_if_needed(stride, "stride", 3);
A Collaborator commented:

Out of curiosity, why do you need to expand stride etc. here, but don't need to do it for 2d conv? If it were not for these expansion calls, the reorder_weight functions would be exactly the same for 2d and 3d.

@XiaobingSuper (Collaborator, Author) replied:

Yes, we don't need to expand them; they have already been expanded at

kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
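
(For reference, a quick runnable check of that expansion; the internal _pair/_triple helpers are used here purely for illustration.)

```python
# The Python nn modules expand scalar conv parameters to per-dimension tuples
# before they ever reach the C++ op, so the native code receives full-length lists.
import torch.nn as nn
from torch.nn.modules.utils import _pair, _triple

print(_pair(2), _triple(2))            # (2, 2) (2, 2, 2)

conv3d = nn.Conv3d(3, 8, kernel_size=3, stride=2, padding=1)
print(conv3d.stride, conv3d.padding)   # (2, 2, 2) (1, 1, 1)
```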

@ngimel (Collaborator) commented Jun 18, 2020

Currently, if I send a tensor of the wrong type (e.g. double) to mkldnn convolution, I get an error

RuntimeError: tensor.scalar_type() == ScalarType::Float INTERNAL ASSERT FAILED at "../aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp":70, please report a bug to PyTorch. itensor_view_from_dense expects float tensor input

which is the wrong error type, thrown by TORCH_INTERNAL_ASSERT. Should be TORCH_CHECK instead.

@XiaobingSuper (Collaborator, Author)

> Currently, if I send a tensor of the wrong type (e.g. double) to mkldnn convolution, I get an error
>
> RuntimeError: tensor.scalar_type() == ScalarType::Float INTERNAL ASSERT FAILED at "../aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp":70, please report a bug to PyTorch. itensor_view_from_dense expects float tensor input
>
> which is the wrong error type, thrown by TORCH_INTERNAL_ASSERT. Should be TORCH_CHECK instead.

Changed to TORCH_CHECK now.

@XiaobingSuper XiaobingSuper requested a review from ngimel June 19, 2020 04:33
xwang233 pushed a commit to xwang233/pytorch that referenced this pull request Jun 20, 2020
Summary:
- Bump DNNL to 1.5
- Bug fixes and improvements in ideep
  - suppress g++ Wreorder warning
  - avoid rebuilding `libmkldnn.so` uxlfoundation/oneDNN#743
  - enable conv3d (integration code was checked in by Xiaobing pytorch#35662)
Pull Request resolved: pytorch#40088

Differential Revision: D22071530

Pulled By: albanD

fbshipit-source-id: e7a53d7421e8a7a03e36a7dfb68edc565a2f00df
@XiaobingSuper (Collaborator, Author)

@VitalyFedyunin

@facebook-github-bot (Contributor)

@VitalyFedyunin merged this pull request in 6ba807c.

@VitalyFedyunin (Contributor)

Hi @XiaobingSuper, we had to revert this stack (see the https://ezyang.github.io/pytorch-ci-hud/build/pytorch-master logs); could you please create new PRs?

@facebook-github-bot facebook-github-bot deleted the gh/xiaobingsuper/9/head branch June 26, 2020 14:16

Labels: Merged, open source, triaged