-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Batched linear system of equations solver (torch.bgesv) #4502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
cc @fritzo |
apaszke
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks great, thanks for wrapping this up so quickly! Only a few minor comments.
| } | ||
|
|
||
| magma_queue_t magma_queue; | ||
| magma_queue_create_from_cuda( |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| THError("MAGMA bgesv (gesv_batched) : For batch number %lld: U(%d,%d) is zero, singular U.", | ||
| (long long)batch_count, info, info); | ||
| } | ||
| } |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| THTensor_(resizeNd)(self, 3, size, stride); | ||
| THTensor_(copy)(self, src); | ||
| return self; | ||
| } |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| x_exp, LU_exp = torch.gesv(b.squeeze(0), A.squeeze(0)) | ||
| x, LU = torch.bgesv(b, A) | ||
| self.assertEqual(x, x_exp.unsqueeze(0)) | ||
| self.assertEqual(LU, LU_exp.unsqueeze(0)) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
- Refactored THCTensor_(newBatchedColumnMajor) and THTensor_(cloneBatchedColumnMajor) - In THCTensor_(bgesv), move error checking to after freeing a lot of things - Added a test to check bgesv against gesv (in a loop) for a batch of size 4
|
A few high-level comments: The equivalent NumPy function is np.linalg.solve
You don't have to block the PR on these suggestions, but if it's not too much extra work see if you can merge gesv and bgesv so that we don't unnecessarily expand the public API. |
|
Closing this in favor of #4612 |
Fixes pytorch#3164 Picks up from pytorch#4502 I moved `gesv` to ATen. Adds bindings for MAGMA's `gesv_batched` function for CUDA. For CPU, runs `THLapack(gesv)` in a for loop. The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch-size `(A x B)`. The overhead of creating the magma_queue_t is: ~350000 microseconds the first time it's called and ~6 microseconds every time after that.
* Add batched linear solver to torch.gesv() Fixes #3164 Picks up from #4502 I moved `gesv` to ATen. Adds bindings for MAGMA's `gesv_batched` function for CUDA. For CPU, runs `THLapack(gesv)` in a for loop. The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch-size `(A x B)`. The overhead of creating the magma_queue_t is: ~350000 microseconds the first time it's called and ~6 microseconds every time after that. * Tests and docs * Address comments * Address comments * Rebase * Address comments * Fix rebase * Addressed comments * Address comments * Address comments * Addressed comments
* Add batched linear solver to torch.gesv() Fixes pytorch#3164 Picks up from pytorch#4502 I moved `gesv` to ATen. Adds bindings for MAGMA's `gesv_batched` function for CUDA. For CPU, runs `THLapack(gesv)` in a for loop. The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch-size `(A x B)`. The overhead of creating the magma_queue_t is: ~350000 microseconds the first time it's called and ~6 microseconds every time after that. * Tests and docs * Address comments * Address comments * Rebase * Address comments * Fix rebase * Addressed comments * Address comments * Address comments * Addressed comments
Implements torch.bgesv, a batched linear system of equations solver.
Adds bindings for MAGMA's gesv_batched function for CUDA.
For CPU, runs THLapack(gesv) in a for loop.
I decided to not build this into
torch.gesvbut if we want to I can change that.cc @apaszke
Test Plan
New unit tests:
torch.gesvand specifying outputs