Logcumsumexp for CUDA (build-time optimized) #94310
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94310
Note: Links to docs will display an error until the docs builds have been completed. ❌ 12 Failures as of commit c05db6c: NEW FAILURES. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Ah, actually, scratch the constexpr comments; it seems there are some implementation issues in custom scalar types.
// custom min and max to be used in logcumsumexp for complex arguments
template <typename scalar_t, bool min>
__host__ __device__ c10::complex<scalar_t> _logcumsumexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
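For readers skimming the diff, here is a rough standalone sketch of what such a helper can look like. It is an illustration only: `std::complex` stands in for `c10::complex`, plain functions stand in for `__host__ __device__` ones, and the dominance rule (compare real parts, i.e. the magnitudes of `exp(x)`) is an assumption about the helper's intent, not the PR's exact code.

```cpp
#include <complex>

// Illustrative sketch only: std::complex stands in for c10::complex.
// Comparing real parts is an assumed stand-in for the PR's actual rule,
// since the real part of x determines the magnitude of exp(x).
template <typename scalar_t, bool min>
std::complex<scalar_t> logcumsumexp_minmax_sketch(
    const std::complex<scalar_t>& x, const std::complex<scalar_t>& y) {
  // min vs. max is selected at instantiation time by the bool parameter
  return min ? (x.real() < y.real() ? x : y)
             : (x.real() > y.real() ? x : y);
}
```

Because the bool is a template parameter, the min/max choice is resolved when the function is instantiated, so the two call sites pay no runtime dispatch cost.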
Actually you can revert the templating arg too; it's a bit difficult to set this up in a constexpr if statement that is clean with all the non-constexpr conditions as well.
Also all the else statements are unnecessary since they all have return statements in them.
What difference does it make if we remove the else statements?
@mfkasim1 just removes extra indentation. That's why it's a nit. Don't really care either way.
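As an illustration of the nit (a generic example, not code from this PR), dropping `else` after early returns preserves behavior while flattening the indentation:

```cpp
// Generic example, not from the PR: both versions return the same values.

// Before: chained else branches, each containing a return.
double clamp_with_else(double v) {
  if (v < 0.0) {
    return 0.0;
  } else if (v > 1.0) {
    return 1.0;
  } else {
    return v;
  }
}

// After: the elses are removed; every path still returns exactly once.
double clamp_early_return(double v) {
  if (v < 0.0) {
    return 0.0;
  }
  if (v > 1.0) {
    return 1.0;
  }
  return v;
}
```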
// handling the "infectious" NaNs
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
}
else if ((!::isfinite(min_real)) && (min_real == max_real)) {
nit, but a lot of the elses also aren't needed here, since it's all just dealing with early returns
malfet left a comment
Let's wait for binary build results, but otherwise looks good to me
@mfkasim1 Looks good to me. Feel free to trigger the merge whenever.

Thanks @malfet @Skylion007

@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR. Details for Dev Infra team: raised by workflow job.
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased f192b9c to c05db6c.
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed; the first few of them are: windows-binary-wheel / wheel-py3_9-cpu-test. Details for Dev Infra team: raised by workflow job.

@pytorchbot merge -f 'Unrelated infra issue. Broken smoketest label binaries'

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hopefully fixes #89205.
This is another version of #90847, which was reverted because it increased the compile time significantly.
From my discussion with @ngimel in #93153 (comment), it seems the option of jiterator would be very tricky if not impossible.
So what I did was to optimize the compile-time in my computer.
To optimize the build time, I first compiled PyTorch as a whole, then changed only the LogcumsumexpKernel.cu file to see how it affects the compile time. Here are my results for the compilation time of only the LogcumsumexpKernel.cu file on my computer: if the previous PR increased the build time by 30 minutes on PyTorch's machines, then this PR reduces the build-time increase to about 6 minutes. Hopefully this is an acceptable level of build-time increase.
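The measurement loop described above can be sketched as follows. The build command is a placeholder for illustration, not the actual PyTorch build invocation; the idea is just to touch one `.cu` file so only that translation unit recompiles, then time the incremental build.

```cpp
#include <chrono>
#include <cstdio>
#include <cstdlib>

// Sketch of timing one incremental rebuild. The command string passed in is
// a placeholder (e.g. "ninja torch_cuda" after touching the kernel file),
// not the real PyTorch build command.
double time_incremental_rebuild(const char* build_cmd) {
  auto start = std::chrono::steady_clock::now();
  int rc = std::system(build_cmd);  // run the incremental build
  auto stop = std::chrono::steady_clock::now();
  if (rc != 0) std::fprintf(stderr, "build command failed\n");
  return std::chrono::duration<double>(stop - start).count();
}
```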
What I did was (sorted by how significantly it reduces the build time, starting from the most significant):

- Change log(x) to log1p(x - 1). This is applied in the infinite case, so we don't really care about precision.

tag: @malfet, @albanD
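The log1p substitution relies on the identity log1p(y) = log(1 + y), so log1p(x - 1) = log(x) for x > 0. A minimal standalone check (using host-side std::log1p rather than the CUDA device function):

```cpp
#include <cmath>

// For x > 0, log1p(x - 1) is mathematically identical to log(x), since
// log1p(y) = log(1 + y). This lets the kernel reuse one special function;
// per the PR, the affected (infinite) case does not need extra precision.
double log_via_log1p(double x) {
  return std::log1p(x - 1.0);
}
```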