Performance and memory improvements to batched torch.linalg.solve (2nd attempt) #71756

Conversation
Previously, for a single input matrix A and a batched matrix B, the matrix A was expanded and cloned before computing the LU decomposition and solving the linear system. With this PR the LU decomposition is computed once for the single matrix and then expanded and cloned only if required by the backend library call.
Implement broadcast linear indexing for lu and pivots.
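The optimisation targets calls like the following, where a single coefficient matrix is solved against a batch of right-hand sides (a minimal sketch of the affected call pattern, with illustrative shapes):

```python
import torch

# A single coefficient matrix solved against a batch of right-hand sides.
# Before this PR, A was expanded and cloned batch-many times before the
# LU factorisation; with this PR the factorisation is computed once for A.
a = torch.randn(64, 64)       # single matrix
b = torch.randn(32, 64, 2)    # batch of 32 right-hand sides

x = torch.linalg.solve(a, b)  # a broadcasts across the batch dimension of b
print(x.shape)                # torch.Size([32, 64, 2])
```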
💊 CI failures summary and remediations — as of commit 262fe91 (more details on the Dr. CI page): 🕵️ 1 new failure recognized by patterns; it does not appear to be due to upstream breakages. This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
lezcano left a comment:
A few comments, but none of them needs to be addressed.
```cpp
  return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}

class BroadcastLinearIndices {
```
I believe this could be used in cusolver's gesvd, as we do a very similar thing there. Not for this PR ofc.
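The indexing trick in the snippet above can be mimicked from Python; the following is a sketch of the idea only (the function name is illustrative, not the internal C++ helper):

```python
import torch

def broadcast_linear_indices(numel, original_shape, broadcast_shape):
    # Linear indices 0..numel-1 laid out in the original (unbatched) shape,
    # then broadcast to the batched shape: every batch element refers to the
    # same underlying matrix, and no copy is made until .contiguous()
    # materialises the expansion.
    return (torch.arange(numel)
                .view(original_shape)
                .broadcast_to(broadcast_shape)
                .contiguous())

idx = broadcast_linear_indices(6, (1, 2, 3), (4, 2, 3))
print(idx.shape)  # torch.Size([4, 2, 3]); each idx[i] is arange(6).view(2, 3)
```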
```diff
 auto b_stride = matrixStride(b);
-auto lu_stride = matrixStride(lu);
 auto pivots_stride = pivots_cpu.size(-1);
+auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
```
I really like this way to get the "matrix stride". This is a great idea!
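The "matrix stride" idea can be checked from Python (my own illustration, not code from the PR): for a contiguous batch of matrices, `stride(-3)` is the number of elements per matrix, and a plain 2-D tensor has no batch dimension at all, hence the `0` fallback.

```python
import torch

lu = torch.randn(4, 3, 3)   # contiguous batch of 4 matrices of size 3x3

# Step between consecutive matrices in the flat storage:
# for a contiguous batch this is stride(-3) == rows * cols.
matrix_stride = lu.stride(-3) if lu.dim() > 2 else 0
print(matrix_stride)        # 9

single = torch.randn(3, 3)  # no batch dimension
print(single.dim() > 2)     # False, so the stride falls back to 0
```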
```cpp
TORCH_INTERNAL_ASSERT(batchCount(b) == batchCount(lu), "batch_size of b and lu must be the same");
TORCH_INTERNAL_ASSERT(batchCount(lu) == batchCount(pivots.unsqueeze(-1)), "batch_size of lu and pivots must be the same");
```
Just another note: I just realised that we can implement a faster batchCount as `t.dim() > 2 ? t.numel() / t.stride(-3) : 1` when we have a batch of column- or row-major matrices.
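The suggested shortcut can be sanity-checked from Python; this sketch assumes the batch of matrices is contiguous, which is exactly the column- or row-major precondition the note mentions:

```python
import torch

def fast_batch_count(t):
    # For a contiguous batch of matrices, stride(-3) == rows * cols,
    # so numel / stride(-3) recovers the number of matrices. A plain
    # 2-D matrix counts as a batch of one.
    return t.numel() // t.stride(-3) if t.dim() > 2 else 1

batched = torch.randn(5, 7, 3, 3)   # 5 * 7 = 35 matrices
single = torch.randn(3, 3)

print(fast_batch_count(batched))    # 35
print(fast_batch_count(single))     # 1
```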
```cpp
c10::MaybeOwned<Tensor> maybe_expand_lu(const Tensor& b, const Tensor& lu) {
  if (batchCount(b) != batchCount(lu)) {
    IntArrayRef b_batch_size(b.sizes().data(), b.dim() - 2);
    std::vector<int64_t> expand_size = b_batch_size.vec();
```
Nit. DimVector is better suited for these use cases.
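The expansion logic above can be sketched in Python (illustrative names and a simplification: the real helper returns a `c10::MaybeOwned<Tensor>` so that no wrapper is allocated when no expansion is needed):

```python
import torch

def maybe_expand_lu(b, lu):
    # Expand lu's batch dimensions to match b's when they differ,
    # materialising the copy only when a broadcast is actually required.
    if b.shape[:-2] != lu.shape[:-2]:
        expand_size = list(b.shape[:-2]) + list(lu.shape[-2:])
        return lu.expand(expand_size).contiguous()
    return lu  # no copy needed

lu = torch.randn(3, 3)     # single factorised matrix
b = torch.randn(8, 3, 2)   # batch of 8 right-hand sides
expanded = maybe_expand_lu(b, lu)
print(expanded.shape)      # torch.Size([8, 3, 3])
```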
CI fails with:

> AssertionError: False is not true : Tensors failed to compare as equal! With rtol=1e-05 and atol=1e-05, found 241 element(s) (out of 256) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.6182836257459539 (0.9761708184203711 vs. 0.3578871926744173), which occurred at index (5, 5).

Tests pass locally. Maybe it's just a bad seed plus the lstsq test being flaky?
This PR does not touch `lstsq`.
mruberry left a comment:
Stamped! Thank you @IvanYashchuk
@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@IvanYashchuk Do you want to address @lezcano's comments and use it as an opportunity to retrigger CI and see if the lstsq error persists? If it does not, we should file an issue for it being a flaky test.
I replaced `std::vector<int64_t>` with `DimVector`. Let's see whether that lstsq test passes now.
This PR has accumulated conflicts and I need to resolve them.
Hey @mruberry, could you please try importing this PR?
Hi @mruberry. Could you import this one? As part of the LU stack that I'm working on, I plan to go and use all the new tools we have to optimise
@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Performance and memory improvements to batched torch.linalg.solve (2nd attempt) (#71756)

Summary: Previous PR with the same content: #69752. Opening a new PR by request: #69752 (comment).

Previously for single input matrix A and batched matrix B, matrix A was expanded and cloned before computing the LU decomposition and solving the linear system. With this PR the LU decomposition is computed once for a single matrix and then expanded and cloned if required by a backend library call for the linear system solving.

Here's a basic comparison:

```python
# BEFORE THE PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
   ...: torch.linalg.solve(a, b)
   ...:
329 ms ± 17.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# WITH THIS PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
   ...: torch.linalg.solve(a, b)
   ...:
21.4 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

Fixes #71406, fixes #71610

Pull Request resolved: #71756
Reviewed By: ngimel
Differential Revision: D33771981
Pulled By: mruberry
fbshipit-source-id: 0917ee36a3eb622ff75d54787b1bffe26b41cb4a
Hey @IvanYashchuk.