Batch potrf #11796
Conversation
I didn't know how getri worked after using getrf, which led to some issues
Native batched potrf. One item of pytorch#7500. Also batch tril/triu in native. This doesn't separate out the single matrix versions, but I don't think it is too inefficient. This builds on the batch linear algebra framework of pytorch#9949 by @vishwakftw.
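For readers skimming the thread, here is a rough sketch of the batched semantics this PR is after, written against today's public API (torch.linalg.cholesky / torch.tril, which expose batched Cholesky and triangular masking; the potrf-era function names may differ):

```python
import torch

# Build a small batch of symmetric positive definite matrices.
a = torch.randn(4, 3, 3)
a = a @ a.transpose(-1, -2) + 3 * torch.eye(3)

# Batched Cholesky: one lower-triangular factor per matrix in the batch.
l = torch.linalg.cholesky(a)   # shape (4, 3, 3)

# Batched tril/triu operate on the last two dimensions in the same way.
lower = torch.tril(a)          # shape (4, 3, 3)

print(torch.allclose(l @ l.transpose(-1, -2), a, atol=1e-5))
```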
With apologies to @ethanluoyc for not being aware of #9623 before; there, just the backward and tests were missing.
Hmm. This looks like something in #9949 (@vishwakftw).
#9949 requires a rebase, yes.
auto m = self.size(-1);
auto self_batched_ = self.view({-1, n, m});
auto self_batched = self_batched_.accessor<scalar_t, 3>();
auto result_batched_ = result.view({-1, n, m});
Sure thing. Where should I put the code? @vishwakftw moved the Gesv.* to BatchLinearAlgebra.* in his batch inverse PR (#9949) and I put my stuff in there. Should I keep the renaming?
Is the plan to add a batched version of potrs as well? It doesn't seem to be part of this PR.
**Edit**: I mixed that up with pstrf, which isn't on GPU as far as I know, sorry. Thanks vishwakftw for correcting me!
My other priority is "pure GPU potrs".
@t-vi could you rebase?
I looked into this; rebasing seems to be more effort than starting over.
  n = n % stride_batch;
} else if (stride_batch > stride_min) {
  n = n - n % stride_max + n % stride_batch; // eliminate batch part
} // if stride_batch < stride_min, the divisions below will eliminate batch
@t-vi Could you please give a small explanation of the logic here? It doesn't seem very obvious to me, unfortunately. Thank you.
Here, you have stride_max > stride_batch > stride_min, and you now want to eliminate the batch contribution from the offset. Subtracting n % stride_max eliminates both the batch and the "lower" contribution, and adding n % stride_batch back restores the lower part. As a result, exactly the batch offset is removed.
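For what it's worth, a tiny numeric sketch of that index arithmetic (the concrete stride values are made up; it assumes stride_max > stride_batch > stride_min and that the batch-plus-lower part of the offset stays below stride_max):

```python
stride_max, stride_batch, stride_min = 64, 16, 1

# An offset n with contributions along all three strides.
n = 2 * stride_max + 3 * stride_batch + 5 * stride_min   # 181

# n % stride_max drops both the batch and the "lower" contribution;
# adding n % stride_batch restores the lower part, so only the batch offset is removed.
without_batch = n - n % stride_max + n % stride_batch
assert without_batch == 2 * stride_max + 5 * stride_min  # 133
```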
I got it. Thank you very much.
Native batched potrf. One item of #7500.
Also batch tril/triu in native.
This doesn't separate out the single matrix versions, but I don't
think it is too inefficient.
This builds on the batch linear algebra framework of #9949 by @vishwakftw