
Conversation

@t-vi
Collaborator

@t-vi t-vi commented Sep 5, 2018

  • Incorporates the MKL addition by @mingfeima. Thank you! (but all errors are my own)
  • Native CPU implementation: defer to matrix multiplication for
    small batches and parallelize over the batch dimension for large
    batches. (More improvement might be achieved by calling MKL's
    batch gemm; see the sketch below.)
  • Add a bmm test for CUDA just to be sure.

This is a partial fix for #10661, getting the slowdown down to a factor of ~5.
Considerable overhead is incurred by the setup in einsum. It might
be more efficient to eventually define optimized contraction
functions for arbitrary and multiple dimensions.
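
To make the strategy above concrete, here is a minimal sketch of such a dispatch written against the ATen C++ API. The function name, the float-only accessors, and the thresholds (the m*k*n < 400 cutoff and the bs > 100 parallelization guard, both taken from later comments in this thread) are illustrative assumptions, not the PR's actual code.

// Rough sketch of the CPU strategy described above; not the PR's exact code.
// Assumes 3D float tensors of shapes (bs, is, ks) and (bs, ks, js).
#include <ATen/ATen.h>

at::Tensor bmm_cpu_sketch(const at::Tensor& self, const at::Tensor& mat2) {
  const int64_t bs = self.size(0);
  const int64_t is = self.size(1);
  const int64_t ks = self.size(2);
  const int64_t js = mat2.size(2);
  auto result = at::empty({bs, is, js}, self.options());

  if (is * ks * js >= 400) {
    // Big enough matrices: defer to the tuned mm kernel, one batch at a time.
    for (int64_t b = 0; b < bs; b++) {
      auto r = result.select(0, b);
      at::mm_out(r, self.select(0, b), mat2.select(0, b));
    }
  } else {
    // Tiny matrices: naive triple loop, parallelized over the batch dimension
    // once there are enough batches to amortize the threading overhead.
    auto s0 = self.accessor<float, 3>();
    auto m0 = mat2.accessor<float, 3>();
    auto r0 = result.accessor<float, 3>();
    #pragma omp parallel for if (bs > 100)
    for (int64_t b = 0; b < bs; b++) {
      for (int64_t i = 0; i < is; i++) {
        for (int64_t j = 0; j < js; j++) {
          float acc = 0;
          for (int64_t k = 0; k < ks; k++) {
            acc += s0[b][i][k] * m0[b][k][j];
          }
          r0[b][i][j] = acc;
        }
      }
    }
  }
  return result;
}

The actual patch additionally dispatches over scalar types and covers the baddbmm/out variants, which this sketch omits.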


@t-vi
Collaborator Author

t-vi commented Sep 6, 2018

Hm, broadcasting. I only remembered that bmm didn't broadcast... :/ Sorry, I'll get that fixed.

@t-vi
Collaborator Author

t-vi commented Sep 7, 2018

So the two build failures are:

  • FAIL: test_scalar_fusion (__main__.TestScript) on pr/pytorch-linux-trusty-py2.7. This seems to happen on several, but not all, Jenkins builds. I don't know why; could the test be flaky?
  • On pr/caffe2-py2-mkl-ubuntu16.04-test there is a timeout in a distributed operation.

@ssnl
Collaborator

ssnl commented Sep 7, 2018

test_scalar_fusion is indeed flaky (see #11360)
The caffe2 test is also flaky :/ (see #8982)

@soumith
Contributor

soumith commented Sep 7, 2018

@t-vi it'd be good to also fold the changes from #11365 into this PR. cc: @mingfeima

@soumith
Contributor

soumith commented Sep 7, 2018

Oh, I just noticed that you already commented on that PR; my bad for not noticing earlier.

@ezyang
Contributor

ezyang commented Sep 9, 2018

The TH (not THC) implementation is dead now, right? Can you delete it?


@ezyang
Contributor

ezyang commented Sep 9, 2018

error is legit:

14:38:48 RuntimeError: Function '_th_bmm' starts with a single underscore and is configured to have a method on Tensor. Functions that start with  a single underscore should only be functions in the at:: namespace and not methods on Tensor!

Just delete variant: method from the cwrap

@t-vi
Collaborator Author

t-vi commented Sep 9, 2018

Thank you for your comments! I will improve the PR based on your feedback.
So I tried to adapt @mingfeima 's patch #11365 to have a baddbmm__mkl (baddbmm_), but it is difficult:

  • The _mkl suffix messes up the type translations (non-constness of self); should I just adapt the types, or would a prefix be OK?
  • @fritzo 's benchmark from the bug report seems to indicate that it is slow when I compile with MKL. :(
    I will try comparing with the original #11365 PR ("use batched gemm from mkl on torch.bmm when mkl is available") to see whether these are my adaptations or whether MKL is indeed slow in that case.

@ezyang
Contributor

ezyang commented Sep 9, 2018

The _mkl suffix messes up the type translations (non-constness of self); should I just adapt the types, or would a prefix be OK?

Prefix is fine. The reason it's messing up is that it looks for _out or _ as a suffix to decide what the types should be. So just don't put mkl at the very end and you'll be fine.

@ssnl
Collaborator

ssnl commented Sep 9, 2018

I'd just name it as baddbmm_mkl_

@mingfeima
Collaborator

mingfeima commented Sep 10, 2018

@t-vi @soumith folding in #11365 would be better :)
Sorry, I wasn't looking at baddbmm at the beginning; it's better to have it optimized together with bmm.
I have been looking at OpenNMT; bmm is used quite a lot in the global attention computation.

From the optimization point of view, this can be done in two ways:

  1. MKL batched gemm: as in #11365 (use batched gemm from mkl on torch.bmm when mkl is available).
  2. OpenMP + sequential gemm: do OpenMP parallelization across the batch_size dimension and call sequential MKL gemm.

I actually did both. The performance gain from MKL batched gemm is only mediocre; OpenMP + sequential gemm performs much better than MKL batched gemm in a unit benchmark. But once tested from PyTorch it is not that good; it needs further refinement to fit PyTorch's threading model.

Anyway, at this moment it is fair enough to stick with MKL batched gemm for this PR. I checked implementations from other major frameworks and found no better approach.

@t-vi, a few suggestions on this PR:

  1. Use MKL to perform gemm. gemm is actually quite a complex topic; you have to deal with blocking, vectorization, etc.
  2. I suppose you don't have to use #pragma omp parallel for explicitly, for now. Even if you do, using the parallel_for wrapper from aten/src/ATen/Parallel.h would probably be a better idea.
  3. aten/src/ATen/native/LinearAlgebra.cpp is kind of a unified entrance for all backends, so putting the MKL-related functions under aten/src/ATen/native/mkl/ would be better. You can guard compilation with AT_MKL_ENABLED() and runtime selection with at::hasMKL(). Similarly, in case I figure out how to fix the OpenMP + sequential gemm issue, that should go to aten/src/ATen/native/cpu.
  4. baddbmm and bmm can share the same underlying MKL kernels; bmm is a specialization of baddbmm with alpha = 1 and beta = 0 (see the sketch after this list).
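
To illustrate points 3 and 4, here is a rough sketch of how the wiring could look. Only AT_MKL_ENABLED() and at::hasMKL() are the real ATen facilities mentioned above; the function names and the addmm_-based fallback are hypothetical, and the MKL kernel body is elided.

// Sketch only (not the code that landed): wiring for suggestions 3 and 4.
#include <ATen/ATen.h>
#include <ATen/Config.h>  // AT_MKL_ENABLED()

// Shared entry point: result = beta * result + alpha * (batch1 @ batch2)
at::Tensor& baddbmm_cpu_sketch(at::Tensor& result, const at::Tensor& batch1,
                               const at::Tensor& batch2, double beta, double alpha) {
#if AT_MKL_ENABLED()
  if (at::hasMKL()) {
    // A baddbmm kernel calling cblas_?gemm_batch would live under
    // aten/src/ATen/native/mkl/; its body is elided from this sketch.
    // return baddbmm_mkl_(result, batch1, batch2, beta, alpha);
  }
#endif
  // Reference fallback: one addmm_ per batch element.
  for (int64_t b = 0; b < result.size(0); b++) {
    auto r = result.select(0, b);
    r.addmm_(batch1.select(0, b), batch2.select(0, b), beta, alpha);
  }
  return result;
}

// Suggestion 4: bmm is baddbmm with beta = 0 and alpha = 1 into a fresh output.
at::Tensor bmm_via_baddbmm(const at::Tensor& batch1, const at::Tensor& batch2) {
  auto result = at::zeros({batch1.size(0), batch1.size(1), batch2.size(2)},
                          batch1.options());
  return baddbmm_cpu_sketch(result, batch1, batch2, /*beta=*/0.0, /*alpha=*/1.0);
}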

@t-vi, one more question: what type of CPU are you using?

MKL support by Ma Mingfei, @mingfeima, errors by myself
@t-vi
Collaborator Author

t-vi commented Sep 10, 2018

@mingfeima Thank you for the detailed comments!
I think I need some more detail to follow your suggestions:

  1. "Use MKL to perform gemm" means that I should use your batched gemm? I have added that locally, but the performance seems bad compared to spelling out the matrix multiplication when the matrices are very small (as in fritzo's benchmark).
  2. For non-MKL, I'm not sure how to avoid a parallel for (though using Parallel.h is much better than trying to do this myself, thanks for pointing that out!).
  3. My impression was that native/cpu is for specialized, accelerated (e.g. AVX) code and "generic" CPU code lives in native (e.g. Embedding, ...).
  4. I think I have that locally. :)

I currently run this on my laptop (a ThinkPad with an Intel Core i7-5600U, on Linux); I could move it to my GPU host (i5-7500).

Update:

So on the i7-5600U, the modified benchmark by fritzo gives me

torch: 32.022253560018726
numpy: 0.017791966965887696
torch.bmm: 31.382055553025566

for plain #11365. It seems that batched gemm is only advantageous for larger matrices; indeed, with 2000 200x200 / 200x1 matrices, torch's MKL bmm outperforms numpy (numpy is stock from Debian).

For reference:

import timeit

import numpy
import torch

print(torch.__version__)
x = torch.randn((2, 2000))
y = torch.randn((2, 2, 2000))
xp = x.permute(1, 0).view(2000, 1, 2)
yp = y.permute(2, 0, 1)
equation = 'ac,abc->cb'

time0 = timeit.default_timer()
for _ in range(1000):
    _ = torch.einsum(equation, [x, y])

time1 = timeit.default_timer()
for _ in range(1000):
    _ = numpy.einsum(equation, x.numpy(), y.numpy())

time2 = timeit.default_timer()

for _ in range(1000):
    _ = torch.bmm(xp, yp)
time3 = timeit.default_timer()

print((torch.einsum(equation, [x, y])-torch.bmm(xp, yp).squeeze()).abs().max().item())

print('torch: {}'.format(time1 - time0))
print('numpy: {}'.format(time2 - time1))
print('torch.bmm: {}'.format(time3 - time2))

@mingfeima
Collaborator

mingfeima commented Sep 10, 2018

@t-vi perhaps you will get a different result on a Xeon.
I test benchmarks on server CPUs, a.k.a. Xeon, and the current most high-end model is the Skylake 8180 with 56 cores. Desktop CPUs are actually not the major optimization target for Intel. I suggest you try a Xeon; they are available on clouds, e.g. AWS.

Sorry for the misleading comment; the problem size with einsum here is much smaller than in my case.
My problem size is a batch_size around 64 and a gemm size around 1*50*800. That type of scenario is perfect for batched gemm.

The problem size with einsum is batch_size 2000 and gemm 1*2*2; using batched gemm will not help there. A 1*2*2 gemm should not go to gemm at all, it is too small; single-threaded vectorized code should have the best performance. And this is not related to the CPU model, I see similar results on Xeon as well.

Mathematically, they are both bmm, but from the optimization point of view they are two quite different problems.

Any information on typical einsum problem sizes? Is the gemm always 1*2*2? We need to figure out how we can come up with a better solution.

First thoughts here:

  1. Put a threshold on MNK. For example, if MNK > 16*16*16, go with MKL batched gemm; otherwise go with the manually written one as proposed by @t-vi (see the sketch after this list).
  2. As for the manually written one, the threshold for OpenMP parallelization needs to be tuned a little. From my experience, good candidates for the threshold are probably 5000~20000: because the computation within each thread is really small, you will only see a performance gain from OpenMP parallelization if you have enough work.
  3. If this can't reach satisfactory performance for einsum, we perhaps need to consider isolating einsum from bmm and using specific vectorized CPU kernels.
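
To make the proposed heuristic concrete, a tiny sketch of the decision in points 1 and 2 might look as follows. The thresholds are the ones suggested above, the reading of the 5000~20000 range as a batch count is an assumption, and the function itself is purely illustrative.

// Illustrative only: encodes the thresholds proposed above, untuned.
#include <cstdint>

enum class BmmPath { MklBatchedGemm, NaiveOpenMP, NaiveSerial };

// m, n, k: per-matrix gemm sizes; bs: batch size.
BmmPath choose_bmm_path(std::int64_t bs, std::int64_t m, std::int64_t n, std::int64_t k) {
  if (m * n * k > 16 * 16 * 16) {
    return BmmPath::MklBatchedGemm;  // point 1: large enough for batched gemm
  }
  // Point 2: only parallelize the tiny-matrix kernel when there are enough
  // batches to amortize thread startup (5000~20000, read here as a batch count).
  return bs > 10000 ? BmmPath::NaiveOpenMP : BmmPath::NaiveSerial;
}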

@t-vi
Collaborator Author

t-vi commented Sep 10, 2018

@mingfeima OK, I'll make sure not to benchmark MKL on my hardware.

In your opinion, on the Xeon hardware that is the optimization focus of MKL, should MKL batch gemm be unconditionally used or should it depend on matrix size?

Edit: Thank you for your input! I'll use m*k*n < 400 as the threshold for now (a factor of ~10 off your suggestion, but I hope it is OK, and MKL seems to work great for that choice). I also changed to parallel_for and specified the grain size as MIN_GRAIN/(m*k*n) rather than meddling with OpenMP myself.
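
For reference, a minimal sketch of what that might look like with at::parallel_for from ATen/Parallel.h, using at::internal::GRAIN_SIZE as a stand-in for the MIN_GRAIN constant mentioned above (an assumption); the float-only kernel body is illustrative, not the PR's code.

// Sketch: parallelize the naive kernel over the batch dimension with a grain
// size scaled by the per-batch work. at::internal::GRAIN_SIZE stands in for
// the MIN_GRAIN constant mentioned above.
#include <algorithm>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>

void naive_bmm_batched(at::Tensor& result, const at::Tensor& self,
                       const at::Tensor& mat2) {
  const int64_t bs = self.size(0);
  const int64_t m = self.size(1), k = self.size(2), n = mat2.size(2);
  auto s = self.accessor<float, 3>();
  auto m2 = mat2.accessor<float, 3>();
  auto r = result.accessor<float, 3>();
  // More work per batch element -> smaller grain, so threads kick in sooner.
  const int64_t grain =
      std::max<int64_t>(at::internal::GRAIN_SIZE / (m * k * n), 1);
  at::parallel_for(0, bs, grain, [&](int64_t begin, int64_t end) {
    for (int64_t b = begin; b < end; b++) {
      for (int64_t i = 0; i < m; i++) {
        for (int64_t j = 0; j < n; j++) {
          float acc = 0;
          for (int64_t kk = 0; kk < k; kk++) {
            acc += s[b][i][kk] * m2[b][kk][j];
          }
          r[b][i][j] = acc;
        }
      }
    }
  });
}

Scaling the grain size by the per-batch work means larger matrices start using multiple threads at smaller batch counts, rather than relying on a fixed batch-size cutoff.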

Refined the switch between the batch-parallel naive kernel and
a series of mm calls.
Bypass MKL based on the same criterion, following a very rough
benchmark.
Add a comment about the optimization.
@t-vi
Collaborator Author

t-vi commented Sep 10, 2018

Thank you all for your input; I hope you find that the changes make good use of it!

I don't check thoroughly enough whether the arguments are contiguous or transposed-contiguous before calling MKL; I'll fix that.

@t-vi
Collaborator Author

t-vi commented Sep 11, 2018

I'm not really sure about the failing CI test, but other than that, I think the PR should be ready for review.

Contributor

@facebook-github-bot facebook-github-bot left a comment

ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 12, 2018
Summary:
- Incorporates the MKL addition by mingfeima. Thank you! (but all errors are my own)
- Native CPU implementation: defer to matrix multiplication for
  small batches and parallelize over batch dimension for large
  batches.
- Add bmm test for CUDA just to be sure.

This is a partial fix for #10661, getting the slowdown down to a factor of ~5.
Considerable overhead is incurred by the setup in einsum. It might
be more efficient to eventually define optimized contraction
functions for arbitrary and multiple dimensions.
Pull Request resolved: pytorch/pytorch#11292

Differential Revision: D9784941

Pulled By: ezyang

fbshipit-source-id: f6dded2c6f5e8f0461fb38f31f9a824992a58358
petrex pushed a commit to petrex/pytorch that referenced this pull request Sep 12, 2018
* master: (165 commits)
  Aibench for asr decoder
  Explicitly set locale on docs build. (pytorch#11595)
  Documentation for debugging JIT
  Fused weightnorm for ATen (pytorch#10842)
  Move Type, Tensor, TensorMethods to core.
  Add reminder % to the jit
  Fix reloading modules back into python (pytorch#11552)
  Add trigonometry functions to docs/source/onnx.rst
  Add EndToEndHybridModel CUDA tests (pytorch#11544)
  minor formatting error log (pytorch#11528)
  Warn that export+import module always load onto the CPU (pytorch#11485)
  caffe2::StorageImpl use at::DataPtr (pytorch#11282)
  Sync all libnccl soversions, not just libnccl.so.1 (pytorch#11575)
  Document BatchNorm and update default behavior (pytorch#11484)
  Typo fix in randomness.rst (pytorch#11571)
  Move some bmm/baddbmm to ATen (pytorch#11292)
  Make c10d test work on CPU only build (pytorch#11567)
  Clean up some C++ cruftiness in the script lexer.
  Allow setting deletion constant
  Make C10d support CPU only build (pytorch#11513)
  ...