Conversation

@zou3519
Contributor

@zou3519 zou3519 commented Jan 5, 2018

Implements `torch.bgesv`, a batched solver for systems of linear equations.

Adds bindings for MAGMA's `gesv_batched` function for CUDA.
For CPU, runs `THLapack(gesv)` in a for loop.

I decided not to build this into `torch.gesv`, but if we want to I can change that.

cc @apaszke

Test Plan

New unit tests:

  • CPU test against torch.gesv and specifying outputs
  • GPU test for the same thing
  • autograd test for derivative
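For context, the semantics being tested can be sketched in NumPy (illustrative only; `np.linalg.solve` stands in for the proposed `torch.bgesv`, and the shapes and shift used to keep the matrices well-conditioned are this sketch's own choices):

```python
import numpy as np

# Semantics of a batched solve: solve A[i] @ x[i] = b[i] independently for each i.
batch, n = 4, 3
rng = np.random.default_rng(0)
A = rng.standard_normal((batch, n, n)) + n * np.eye(n)  # shifted diagonal, well-conditioned
b = rng.standard_normal((batch, n, 1))

# CPU strategy in this PR: one single-matrix solve per batch element, in a loop.
x_loop = np.stack([np.linalg.solve(A[i], b[i]) for i in range(batch)])

# NumPy's solve handles the leading batch dimension directly.
x_batched = np.linalg.solve(A, b)

assert np.allclose(x_loop, x_batched)
```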

@pytorchbot
Collaborator

@zou3519, thanks for your PR! We identified @zdevito to be a potential reviewer.

@apaszke
Contributor

apaszke commented Jan 5, 2018

cc @fritzo

@fritzo
Collaborator

fritzo commented Jan 5, 2018

cc @tbrx @dwd31415

Contributor

@apaszke apaszke left a comment


That looks great, thanks for wrapping this up so quickly! Only a few minor comments.


- Refactored THCTensor_(newBatchedColumnMajor) and
  THTensor_(cloneBatchedColumnMajor)
- In THCTensor_(bgesv), move error checking to after freeing a lot of
  things
- Added a test to check bgesv against gesv (in a loop) for a batch of size 4
@colesbury
Member

A few high-level comments:

The equivalent NumPy function is `np.linalg.solve`.

  1. We should prefer one function that handles both batched and non-batched computation.
  2. We should treat the left-most dimensions as part of the batch. For example, a 4-d tensor of shape `A x B x M x M` should be treated as having batch size `(A x B)`.
  3. It would be good to support broadcasting.
  4. When adding new functions, consider adding them as native functions in ATen.

You don't have to block the PR on these suggestions, but if it's not too much extra work see if you can merge gesv and bgesv so that we don't unnecessarily expand the public API.

@zou3519
Copy link
Contributor Author

zou3519 commented Jan 11, 2018

Closing this in favor of #4612

@zou3519 zou3519 closed this Jan 11, 2018
zou3519 added a commit to zou3519/pytorch that referenced this pull request May 7, 2018
Fixes pytorch#3164
Picks up from pytorch#4502

I moved `gesv` to ATen.
Adds bindings for MAGMA's `gesv_batched` function for CUDA.
For CPU, runs `THLapack(gesv)` in a for loop.

The new function supports arbitrary batch dimensions (and broadcasting
of those dimensions). For example, the 4-d tensor `A x B x M x M` should
be treated as having batch-size `(A x B)`.

The overhead of creating the magma_queue_t is: ~350000 microseconds
the first time it's called and ~6 microseconds every time after that.
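The "arbitrary batch dimensions" behavior amounts to flattening all leading dimensions into a single batch axis before solving. Sketched with NumPy (shapes illustrative):

```python
import numpy as np

# A (2, 3, 4, 4) tensor of matrices is solved as a batch of 2 * 3 = 6
# independent 4x4 systems; the batch shape is restored afterwards.
A = np.random.default_rng(3).standard_normal((2, 3, 4, 4)) + 4 * np.eye(4)
b = np.random.default_rng(4).standard_normal((2, 3, 4, 1))

x = np.linalg.solve(A, b)
x_flat = np.linalg.solve(A.reshape(6, 4, 4), b.reshape(6, 4, 1)).reshape(2, 3, 4, 1)
assert np.allclose(x, x_flat)
```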
soumith pushed a commit that referenced this pull request May 8, 2018
* Add batched linear solver to torch.gesv()

Fixes #3164
Picks up from #4502

I moved `gesv` to ATen.
Adds bindings for MAGMA's `gesv_batched` function for CUDA.
For CPU, runs `THLapack(gesv)` in a for loop.

The new function supports arbitrary batch dimensions (and broadcasting
of those dimensions). For example, the 4-d tensor `A x B x M x M` should
be treated as having batch-size `(A x B)`.

The overhead of creating the magma_queue_t is: ~350000 microseconds
the first time it's called and ~6 microseconds every time after that.

* Tests and docs

* Address comments

* Address comments

* Rebase

* Address comments

* Fix rebase

* Addressed comments

* Address comments

* Address comments

* Addressed comments
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
