
Conversation

pinzhenx (Collaborator) commented Dec 5, 2020

Fixes #35937

Summary

This PR aims to optimize the default CPU path of Convolution with the MKLDNN kernel. Previously, users found that MKLDNN Conv underperformed THNN Conv in some circumstances (issue #35937, PR #40610, PR #46675), especially when the kernel size equals one or the kernel is significantly larger than the input; in those cases, the MKLDNN kernel could be 2x slower than THNN.

Now we've improved the heuristics for Conv algorithm selection and cut overhead in some kernels, achieving performance equal to or better than THNN.
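The selection logic itself lives in PyTorch's C++ MKLDNN/ideep integration and isn't shown in this description; the following is a minimal Python sketch of the kind of heuristic involved, with made-up branch conditions for illustration only, not the actual PR code:

```python
# Illustrative only: a rough sketch of conv-algorithm selection.
# The real logic lives in PyTorch's C++ MKLDNN/ideep integration.
def pick_conv_algorithm(input_hw, kernel_hw):
    ih, iw = input_hw
    kh, kw = kernel_hw
    if (kh, kw) == (1, 1):
        return "gemm"        # a 1x1 conv is exactly a matrix multiply
    if kh * kw >= ih * iw:
        return "direct"      # kernel ~ input size: skip transform-heavy paths
    return "auto"            # let oneDNN pick (direct/Winograd/JIT) itself
```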

Benchmark

Unit Tests

OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 python <(wget -o /dev/null -qO- https://gist.githubusercontent.com/pinzhenx/8f62d5076bb04f0fd2108380b22dfbaa/raw/afc42165c5103e686a3c489198f8658a9a03a258/benchmark_conv.py)

The shapes in this script (https://gist.github.com/pinzhenx/8f62d5076bb04f0fd2108380b22dfbaa) are collected from the issues and PRs mentioned above. All of the problematic cases have now been fixed by this PR.

                  src                  wei        str        pad   g      mkldnn        thnn result
    [1, 1024, 14, 14]   [2048, 1024, 1, 1]     [2, 2]     [0, 0]   1    2.205463    1.850814   slow (1)
     [1, 512, 28, 28]     [512, 512, 3, 3]     [2, 2]     [1, 1]  32    0.457784    4.703026   fast
     [1, 256, 56, 56]     [256, 256, 3, 3]     [2, 2]     [1, 1]  32    0.780760    6.377566   fast
     [1, 256, 56, 56]     [256, 256, 1, 1]     [1, 1]     [0, 0]   1    2.154216    2.243744   fast
     [1, 128, 56, 56]     [256, 128, 1, 1]     [1, 1]     [0, 0]   1    1.184853    1.222208   fast
     [1, 256, 56, 56]     [512, 256, 1, 1]     [2, 2]     [0, 0]   1    1.348624    1.457991   fast
     [1, 256, 56, 56]     [128, 256, 1, 1]     [1, 1]     [0, 0]   1    1.099386    1.120938   fast
      [1, 1024, 7, 7]   [2048, 1024, 1, 1]     [1, 1]     [0, 0]   1    1.731668    1.756143   fast
      [1, 2048, 7, 7]   [1024, 2048, 1, 1]     [1, 1]     [0, 0]   1    1.749803    1.768455   fast
    [1, 1024, 14, 14]   [1024, 1024, 3, 3]     [2, 2]     [1, 1]  32    0.506553    3.718218   fast
    [1, 1024, 14, 14]    [512, 1024, 1, 1]     [1, 1]     [0, 0]   1    1.289359    1.333481   fast
     [1, 256, 28, 28]     [256, 256, 3, 3]     [1, 1]     [1, 1]  32    0.502607    3.105249   fast
     [1, 3, 224, 224]        [64, 3, 7, 7]     [2, 2]     [3, 3]   1    1.861805    6.059177   fast
     [1, 128, 56, 56]     [128, 128, 3, 3]     [1, 1]     [1, 1]  32    1.043156    4.013101   fast
      [1, 1024, 7, 7]   [1024, 1024, 3, 3]     [1, 1]     [1, 1]  32    0.412053    3.090565   fast
     [1, 512, 28, 28]     [512, 512, 1, 1]     [1, 1]     [0, 0]   1    2.102176    2.191270   fast
     [1, 512, 28, 28]     [256, 512, 1, 1]     [1, 1]     [0, 0]   1    1.073232    1.151083   fast
     [1, 256, 28, 28]     [512, 256, 1, 1]     [1, 1]     [0, 0]   1    1.099426    1.160942   fast
     [1, 512, 28, 28]    [1024, 512, 1, 1]     [2, 2]     [0, 0]   1    1.408962    1.485216   fast
      [1, 64, 56, 56]      [128, 64, 1, 1]     [1, 1]     [0, 0]   1    0.361198    0.381369   fast
      [1, 64, 56, 56]      [256, 64, 1, 1]     [1, 1]     [0, 0]   1    0.683444    0.851324   fast
     [1, 512, 14, 14]     [512, 512, 3, 3]     [1, 1]     [1, 1]  32    0.320881    3.007333   fast
     [1, 512, 14, 14]    [1024, 512, 1, 1]     [1, 1]     [0, 0]   1    1.283734    1.314242   fast
    [1, 1024, 14, 14]   [1024, 1024, 1, 1]     [1, 1]     [0, 0]   1    2.543330    2.556091   fast
       [1, 512, 4, 4]     [512, 512, 3, 3]     [1, 1]     [1, 1]   1    0.845248    0.966355   fast
     [25, 3, 48, 320]        [64, 3, 7, 7]     [1, 1]     [0, 0]   1  117.504581  135.730401   fast
     [1, 3, 384, 288]        [64, 3, 7, 7]     [1, 1]     [0, 0]   1   26.538442   44.603726   fast
 [1, 3, 16, 224, 224]     [32, 3, 1, 7, 7]  [1, 1, 1]  [0, 0, 0]   1  137.520870  381.731755   fast
  [1, 3, 4, 112, 112]     [64, 3, 3, 7, 7]  [1, 1, 1]  [0, 0, 0]   1    7.363028   14.706053   fast
  [1, 256, 8, 14, 14]  [256, 256, 3, 3, 3]  [1, 1, 1]  [0, 0, 0]   1   19.072787   21.632466   fast

(1) This "outlier" is a warm-up artifact: MKLDNN uses a smaller memory footprint at the start of the run.
    Rerunning this shape later in the sequence gives a faster result.
(2) Tested on Skylake 8180 with AVX-512 support.
(3) Timings are in ms. Columns: src = input shape, wei = weight shape, str = stride, pad = padding, g = groups.

Model Tests

For model-level results, we tested two ResNeXt variants and got results comparable to before the patch.

 (ms)              all MKLDNN Conv (after patch)  MKLDNN + THNN Conv (before patch)
 resnext101 32x8d                         266.99                             269.28
 resnext50 32x4d                           83.09                            85.515

Config: Skylake 8180, batch=1, thread=1, jemalloc
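A rough way to reproduce the model-level measurement (assuming torchvision is installed; the jemalloc preloading from the config above is not shown here):

```python
import time
import torch
from torchvision import models  # assumes torchvision is available

torch.set_num_threads(1)        # match the thread=1 config above
model = models.resnext50_32x4d().eval()
x = torch.randn(1, 3, 224, 224) # batch=1, as in the config above

with torch.no_grad():
    for _ in range(10):         # warm-up
        model(x)
    t0 = time.perf_counter()
    for _ in range(50):
        model(x)
print((time.perf_counter() - t0) / 50 * 1e3, "ms per iteration")
```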



@lly-zero-one @bertmaher @dzhulgakov @ngimel @CaoZhongZ @jgong5


ngimel (Collaborator) commented Dec 6, 2020

1x1 convolutions now directly use GEMM calls in PyTorch, without any overhead. What is MKLDNN doing that's better than MKL? Should you guys fix MKL for those cases?

pinzhenx (Collaborator, Author) commented Dec 6, 2020

@ngimel For 1x1 conv, we also call SGEMM directly, using the SGEMM reimplemented in MKLDNN, so it's reasonable that we have a small margin over the MKL-based THNN conv.
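The equivalence being relied on here is easy to check from Python: a 1x1 convolution (stride 1, no padding) is exactly a GEMM over the flattened spatial dimensions. A small self-contained demonstration, with shapes taken from the table above:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 28, 28)   # N, C_in, H, W
w = torch.randn(256, 512, 1, 1)   # C_out, C_in, 1, 1

conv_out = F.conv2d(x, w)
# Same computation as a plain matrix multiply:
# (C_out, C_in) @ (C_in, H*W), then restore the spatial dims.
gemm_out = (w.view(256, 512) @ x.view(1, 512, -1)).view(1, 256, 28, 28)

print(torch.allclose(conv_out, gemm_out, atol=1e-4))  # True
```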

CaoZhongZ (Contributor) commented:

> 1x1 convolutions now directly use GEMM calls in PyTorch, without any overhead. What is MKLDNN doing that's better than MKL? Should you guys fix MKL for those cases?

We removed unnecessary transforms in the 1x1 convolution path, unlocking the performance it was always supposed to have. The difference between MKL SGEMM and oneDNN should be trivial; the main difference is that we test and target oneDNN for DNN use cases. We could look into why oneDNN is faster than the MKL call in THNN.
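For context, a sketch at the Python level of the kind of layout transform being discussed (the PR's change is internal to the C++ path; this only shows the reorder that a 1x1 conv can skip):

```python
import torch

x = torch.randn(1, 512, 28, 28)  # plain NCHW tensor

# A reorder into oneDNN's blocked layout (and back) is the kind of transform
# that is pure overhead for a 1x1 conv, which can treat the NCHW buffer
# directly as a GEMM input without any relayout.
y = x.to_mkldnn()                # NCHW -> oneDNN blocked layout
z = y.to_dense()                 # blocked layout -> NCHW
assert torch.equal(x, z)         # a reorder preserves the values exactly
```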

@bertmaher self-requested a review December 7, 2020 17:28
@ailzhang added the "triaged" label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Dec 7, 2020
gottbrath (Contributor) commented:

Comment from Jan 27th discussion:

  • Please benchmark and test on non-AVX-512 use cases (i.e., AVX2).
  • Please explicitly state which version of MKLDNN is required.
  • Commit the benchmark scripts as part of this PR. (@VitalyFedyunin to follow up with instructions)


Labels

cla signed · open source · Stale · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

MKLDNN_conv2d 2X slower than the native TH implementation (#35937)

7 participants