Move some bmm/baddbmm to ATen #11292
Conversation
- Native CPU implementation: defer to matrix multiplication for small batches and parallelize over the batch dimension for large batches. (More improvement might be achieved by calling MKL's batch gemm.)
- Add a bmm test for CUDA just to be sure.

This is a partial fix for pytorch#10661, getting down to a factor of ~5. Considerable overhead is incurred by the setup in einsum. It might be more efficient to eventually define an optimized contraction function over several arbitrary dimensions.
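As a rough illustration of the strategy described above, the CPU dispatch could look like the sketch below. This is not the PR's actual code: the names `baddbmm_cpu_sketch`, `naive_baddbmm_sketch` and the cutoff value are made up; the naive kernel itself is sketched further down, next to the quoted accessor lines.

```cpp
#include <ATen/ATen.h>

// Hypothetical batch-parallel kernel, defined in the sketch further down.
template <typename scalar_t>
void naive_baddbmm_sketch(at::Tensor& result, const at::Tensor& batch1,
                          const at::Tensor& batch2, scalar_t beta, scalar_t alpha);

// Sketch of the dispatch: few batches -> defer to BLAS gemm via addmm_,
// many batches -> naive kernel parallelized over the batch dimension.
// Assumes double tensors of shapes [bs, is, js], [bs, is, ks], [bs, ks, js].
at::Tensor baddbmm_cpu_sketch(const at::Tensor& self, const at::Tensor& batch1,
                              const at::Tensor& batch2, double beta, double alpha) {
  const int64_t bs = batch1.size(0);
  const int64_t small_batch_cutoff = 4;  // hypothetical tuning constant
  auto result = self.clone();            // result starts out holding the "add" term
  if (bs <= small_batch_cutoff) {
    for (int64_t b = 0; b < bs; b++) {
      auto r = result.select(0, b);
      // r = beta * r + alpha * (batch1[b] @ batch2[b]), done by the BLAS gemm
      r.addmm_(batch1.select(0, b), batch2.select(0, b), beta, alpha);
    }
  } else {
    naive_baddbmm_sketch<double>(result, batch1, batch2, beta, alpha);
  }
  return result;
}
```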
```cpp
auto s0 = self.accessor<scalar_t, 3>();
auto m0 = mat2.accessor<scalar_t, 3>();

#pragma omp parallel for if(bs > 100) // or small ks?
```
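The fragment above sets up 3D accessors and guards the OpenMP loop with an `if` clause so threads are only spawned for large batches. A minimal self-contained version of such a batch-parallel naive kernel might look as follows. This is a sketch only: the function name and argument names are assumptions, and `result` is assumed to already hold the `self` term of baddbmm.

```cpp
#include <ATen/ATen.h>

template <typename scalar_t>
void naive_baddbmm_sketch(at::Tensor& result, const at::Tensor& batch1,
                          const at::Tensor& batch2, scalar_t beta, scalar_t alpha) {
  // Shapes: result [bs, is, js], batch1 [bs, is, ks], batch2 [bs, ks, js].
  auto r0 = result.accessor<scalar_t, 3>();
  auto s0 = batch1.accessor<scalar_t, 3>();
  auto m0 = batch2.accessor<scalar_t, 3>();
  const int64_t bs = result.size(0), is = result.size(1),
                js = result.size(2), ks = batch1.size(2);
  // Only spawn threads for a reasonably large batch; the cutoff mirrors the
  // `if(bs > 100)` clause quoted above but is otherwise arbitrary.
  #pragma omp parallel for if (bs > 100)
  for (int64_t b = 0; b < bs; b++) {
    auto r1 = r0[b];
    auto s1 = s0[b];
    auto m1 = m0[b];
    for (int64_t i = 0; i < is; i++) {
      auto r2 = r1[i];
      auto s2 = s1[i];
      for (int64_t j = 0; j < js; j++) {
        scalar_t& r = r2[j];
        r *= beta;                      // a bmm-style kernel would set r = 0 instead
        for (int64_t k = 0; k < ks; k++) {
          r += alpha * s2[k] * m1[k][j];
        }
      }
    }
  }
}
```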
Hm. Broadcasting. I only remembered bmm didn't... :/ Sorry, I'll get that fixed.
So the two build failures are
@t-vi it'd be good to also fold in changes from #11365 into this PR. cc: @mingfeima
oh, I noticed that you already commented on that PR, my bad for not noticing.
The TH (not THC) implementation is dead now, right? Can you delete it?
```cpp
auto s2 = s1[i];
for (int64_t j = 0; j < js; j++) {
  scalar_t &r = r2[j];
  r = 0;
```
```cpp
auto s2 = s1[i];
for (int64_t j = 0; j < js; j++) {
  scalar_t &r = r2[j];
  r *= beta;
```
```cpp
auto s0 = self.accessor<scalar_t, 3>();
auto m0 = mat2.accessor<scalar_t, 3>();

#ifdef _OPENMP
```
```cpp
} else { // split along batch dimension
  if (is_bmm_out) {
    for (int64_t b = 0; b < bs; b++) {
      auto r = self_or_result.select(0, b);
```
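For the `is_bmm_out` branch quoted above, each batch entry of the output is a view that gets filled by an ordinary matrix multiplication, roughly like this. This is an illustrative sketch; only the names `self_or_result`, `bs` and the `select` call come from the quoted fragment.

```cpp
#include <ATen/ATen.h>

// Sketch of the bmm_out path: fill each batch entry of the output with a
// per-batch mm. batch1 and batch2 are the two [bs, ., .] operands.
void bmm_out_by_mm_sketch(at::Tensor& self_or_result,
                          const at::Tensor& batch1, const at::Tensor& batch2) {
  const int64_t bs = self_or_result.size(0);
  for (int64_t b = 0; b < bs; b++) {
    auto r = self_or_result.select(0, b);  // [is, js] view of the output
    at::mm_out(r, batch1.select(0, b), batch2.select(0, b));
  }
}
```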
error is legit: Just delete
Thank you for your comments! I will improve the PR based on your feedback.
Prefix is fine. The reason it's messing up is that it looks for
I'd just name it as
@t-vi @soumith folding in #11365 would be better :) From the optimization point of view, the optimization can be done in two manners.

I actually did both. The performance gain from MKL batched gemm is only mediocre; anyway, at this moment it is fair enough to stick.

@t-vi a few suggestions on this PR.

@t-vi one more question: what type of CPU are you using?
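For context on the MKL route discussed here: MKL's batched CBLAS interface groups GEMMs of equal shape into a single call. A minimal example for one group of `bs` identical double-precision multiplications could look like the sketch below (assuming column-major data; this is not the code from #11365, and all names here are made up).

```cpp
#include <mkl.h>
#include <vector>

// Multiply bs pairs of column-major matrices A[i] (m x k) and B[i] (k x n) into
// C[i] (m x n) with one cblas_dgemm_batch call, treating them as a single group.
void batched_gemm_sketch(MKL_INT bs, MKL_INT m, MKL_INT n, MKL_INT k,
                         std::vector<const double*>& a,
                         std::vector<const double*>& b,
                         std::vector<double*>& c,
                         double alpha, double beta) {
  CBLAS_TRANSPOSE trans = CblasNoTrans;
  MKL_INT lda = m, ldb = k, ldc = m;   // leading dimensions for column-major data
  cblas_dgemm_batch(CblasColMajor,
                    &trans, &trans,    // one transpose flag per group
                    &m, &n, &k,
                    &alpha,
                    a.data(), &lda,
                    b.data(), &ldb,
                    &beta,
                    c.data(), &ldc,
                    /*group_count=*/1, &bs);
}
```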
MKL support by Ma Mingfei, @mingfeima, errors by myself
@mingfeima Thank you for the detailed comments!

I currently run this on my laptop (a Thinkpad with an Intel Core i7-5600U, on Linux); I could move it to my GPU host (i5-7500).

Update: On the i7-5600U, the modified benchmark by fritzo gives me for plain #11365. It seems that batched gemm is only advantageous for larger matrices; indeed, with 2000 200x200 / 200x1 matrices, torch's MKL bmm outperforms numpy (numpy is stock from Debian). For reference:
@t-vi Sorry for the misleading comment.
The problem size with …
The problem size with …
Mathematically, they are both …
Any information on typical …
First thought here is that …
@mingfeima OK, I'll make sure not to benchmark MKL on my hardware. In your opinion, on the Xeon hardware that is the optimization focus of MKL, should MKL batch gemm be used unconditionally, or should it depend on the matrix size?

Edit: Thank you for your input! I'll use
Refined the switch between the batch-parallel naive kernel and a series of mm calls. Bypass MKL based on the same criterion, chosen from a very rough benchmark. Add a comment about the optimization.
Thank you all for your input, I hope you find that the changes make good use of it! I don't check thoroughly enough whether the arguments are contiguous or transposed-contiguous before calling MKL; I'll fix that.
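A check along those lines might look roughly like the sketch below (the helper name is made up; this is not the PR's code): a batch of matrices can be handed to a Fortran-order gemm without copying if each matrix is column-major contiguous, or row-major contiguous so that gemm can treat it as transposed, allowing for a larger leading-dimension stride.

```cpp
#include <ATen/ATen.h>

// Sketch: can a [bs, rows, cols] batch be consumed by a column-major (batched)
// gemm as-is, possibly via a transpose flag, without calling .contiguous() first?
static bool gemm_compatible_sketch(const at::Tensor& t) {
  // Column-major layout: unit stride down a column, leading dimension >= rows.
  const bool col_major = t.stride(1) == 1 && t.stride(2) >= t.size(1);
  // Row-major layout (gemm sees it as transposed): unit stride along a row,
  // leading dimension >= cols.
  const bool row_major = t.stride(2) == 1 && t.stride(1) >= t.size(2);
  return col_major || row_major;
}

// Usage idea: fall back to the native kernel (or copy to contiguous memory)
// when MKL cannot consume the layout directly:
//   if (!(gemm_compatible_sketch(batch1) && gemm_compatible_sketch(batch2))) { ... }
```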
I'm not really sure about the failing CI test, but other than that, I think the PR should be ready for review.
facebook-github-bot
ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
- Incorporates MKL addition by mingfeima. Thank you! (but all errors are my own)
- Native CPU implementation: defer to matrix multiplication for small batches and parallelize over the batch dimension for large batches.
- Add a bmm test for CUDA just to be sure.

This is a partial fix for #10661, getting down to a factor of ~5. Considerable overhead is incurred by the setup in einsum. It might be more efficient to eventually define an optimized contraction function over several arbitrary dimensions.

Pull Request resolved: pytorch/pytorch#11292
Differential Revision: D9784941
Pulled By: ezyang
fbshipit-source-id: f6dded2c6f5e8f0461fb38f31f9a824992a58358
* master: (165 commits)
  Aibench for asr decoder
  Explicitly set locale on docs build. (pytorch#11595)
  Documentation for debugging JIT
  Fused weightnorm for ATen (pytorch#10842)
  Move Type, Tensor, TensorMethods to core.
  Add reminder % to the jit
  Fix reloading modules back into python (pytorch#11552)
  Add trigonometry functions to docs/source/onnx.rst
  Add EndToEndHybridModel CUDA tests (pytorch#11544)
  minor formatting error log (pytorch#11528)
  Warn that export+import module always load onto the CPU (pytorch#11485)
  caffe2::StorageImpl use at::DataPtr (pytorch#11282)
  Sync all libnccl soversions, not just libnccl.so.1 (pytorch#11575)
  Document BatchNorm and update default behavior (pytorch#11484)
  Typo fix in randomness.rst (pytorch#11571)
  Move some bmm/baddbmm to ATen (pytorch#11292)
  Make c10d test work on CPU only build (pytorch#11567)
  Clean up some C++ cruftiness in the script lexer.
  Allow setting deletion constant
  Make C10d support CPU only build (pytorch#11513)
  ...