Add CPU/CUDA support to torch.logcumsumexp #90847
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90847
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 Failures. As of commit d28c203: NEW FAILURES - the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@albanD, could you please help me figure out how to enable automatic differentiation for this function with complex dtypes?
Sure! See pytorch/tools/autograd/gen_variable_type.py, line 171 (at 0b255b3).
This is still an allow-list, since most ops are not supported, and we want to be careful when new ops are added: it is easy to forget to test the complex case.
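For illustration, here is a minimal sketch of the allow-list pattern; the set contents below are assumed, not copied from the file, and only the mechanism matters: enabling complex autograd for a new op means adding its name to the set once its complex gradient is tested.

```python
# Hedged sketch of the allow-list pattern in tools/autograd/gen_variable_type.py.
# The entries shown here are illustrative; the real file lists every audited op.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
    "add",
    "mul",
    "view",
    "logcumsumexp",  # added once the complex backward is implemented and tested
}
```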
The errors seem to be precision issues with the gradients that you compute, right?
@albanD yes, but it seems there are other unrelated errors that I don't know how to get rid of.
albanD left a comment:
Looks quite good! Only small comments.
Review thread on test/test_reductions.py (outdated):
You might want to define this as a ref for the OpInfo, and it will check that the ref matches for all simple inputs without the need for this custom code. Or are there more cases tested here that cannot be tested via the OpInfo?
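For context, a minimal sketch of what such a reference could look like (the helper name is hypothetical, not code from this PR; OpInfo's ref hook compares the op against a NumPy/SciPy implementation on simple inputs):

```python
import numpy as np
from scipy.special import logsumexp

def ref_logcumsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    # Hypothetical reference: logsumexp over each prefix along `axis`.
    slices = []
    for i in range(x.shape[axis]):
        prefix = np.take(x, indices=range(i + 1), axis=axis)
        slices.append(logsumexp(prefix, axis=axis))
    return np.stack(slices, axis=axis)
```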
I found several challenges in using OpInfo (I probably don't understand OpInfo well enough to fully utilize it) for this complex logcumsumexp function:
- To compare two outputs of `logcumsumexp`, the imaginary part needs to be standardized to lie within `(0, 2*pi)` or `(-pi, pi)`. This is because `log(r * e^{i*t}) = log|r| + i * (t + 2*pi*n)`.
- `scipy.special.logsumexp` gives some confusing answers in cases involving `inf` (see my comment around line 590 of this file).
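To make the first point concrete, here is a minimal sketch of normalizing the imaginary part before comparing two complex outputs (`wrap_imag` is a hypothetical helper, not code from this PR):

```python
import math
import torch

def wrap_imag(x: torch.Tensor) -> torch.Tensor:
    # Map the imaginary part into [-pi, pi) so that results differing
    # by i*2*pi*n compare as equal.
    re, im = torch.real(x), torch.imag(x)
    im = torch.remainder(im + math.pi, 2 * math.pi) - math.pi
    return torch.complex(re, im)

# torch.allclose(wrap_imag(out), wrap_imag(expected)) is then a fair comparison.
```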
albanD left a comment:
Thanks for the update. Sounds good!
Thanks for the approval, @albanD! However, the last time the Windows instance ran, it still produced an error in the test. I can't reproduce this error on my machine; somehow the error is raised only on the Windows instance (not the others). It's related to an edge case involving `inf`.
Math libraries often differ on Windows and produce different results in edge cases; I think you can just skip the Windows test (we have quite a few such skips already).
Thanks, @ngimel! I'll push an update that skips the test on Windows.
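For reference, a sketch of what such a skip might look like (the test class, name, and tensor values are hypothetical; `IS_WINDOWS` comes from PyTorch's test utilities, and the call assumes complex `logcumsumexp` support from this PR):

```python
import unittest
import torch
from torch.testing._internal.common_utils import IS_WINDOWS

class TestLogcumsumexpComplex(unittest.TestCase):
    @unittest.skipIf(IS_WINDOWS, "Windows math libraries differ in inf edge cases")
    def test_inf_edge_case(self):
        # hypothetical edge-case input with a -inf real part
        x = torch.tensor([complex(float("-inf"), 1.0), complex(1.0, 2.0)])
        out = torch.logcumsumexp(x, dim=0)
        self.assertEqual(out.shape, x.shape)
```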
I think I'm done with the inf-nan bug on Windows. If any of you would like to comment on anything in this PR, please let me know.
albanD left a comment:
Sounds good to me
What's next for this?
Oh sorry, once the PR is approved, you can ask the bot to merge it yourself :) Try @pytorchbot -h
PyTorchBot help topics: Merge, Revert, Rebase, Label, Dr CI
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR: @pytorchbot rebase
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Successfully rebased and force-pushed from f1b5289 to d28c203.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 3 jobs have failed; the first few of them are: linux-binary-libtorch-cxx11-abi / libtorch-cpu-shared-with-deps-cxx11-abi-build / build; trunk / macos-12-py3-arm64 / test (functorch, 1, 1, macos-m1-12); linux-binary-libtorch-pre-cxx11 / libtorch-cpu-shared-with-deps-pre-cxx11-build / build.
@pytorchbot merge -f "unrelated error"
You are not authorized to force merges to this repository. Please use the regular @pytorchbot merge command instead.
@pytorchbot merge -f "Flaky CI"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR increased nightly build time by 30 min; trying a few remedies.
We also observed this code timing out our build.
@mfkasim1, is the performance of this operation important? Such an increase in build time for a niche op is not great.
@pytorchbot revert -m "Reverting to decrease build time, let's discuss the alternatives here" -c weird
@pytorchbot successfully started a revert job. Check the current status here.
@mfkasim1 your PR has been successfully reverted.
This reverts commit 6498512. Reverted #90847 on behalf of https://github.com/malfet due to: "Reverting to decrease build time, let's discuss the alternatives here".
Yes, this is the bottleneck operation in my research.
Can you give some quantifiable numbers, i.e., how bad would the perf be if it's a composite op? Also, can we make it jiterator-able for just complex numbers? Or keep it out of core?
I tried implementing it with real `logcumsumexp` operations:

```python
import numpy as np
import torch

def _logcumsumexp(z: torch.Tensor, dim: int) -> torch.Tensor:
    # z is complex: z = q + i*k, so exp(z) = exp(q) * (cos(k) + i*sin(k))
    q = torch.real(z)
    k = torch.imag(z)
    # real and imaginary parts of cumsum(exp(z)), each kept in log-space
    a = _logcumsum_aexp(torch.cos(k), q, dim=dim)
    b = _logcumsum_aexp(torch.sin(k), q, dim=dim)
    # combine: log(A + i*B) = log(exp(a) + exp(b + i*pi/2))
    c = _log_add_exp(a, b + 0.5j * np.pi)
    return c

def _logcumsum_aexp(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    # log(cumsum(a * exp(b))); a & b are real, but the returned values are complex.
    # Split a into positive and negative parts so both go through the real op.
    log_a_pos = torch.log(torch.clamp(a, min=torch.finfo(a.dtype).tiny))
    log_a_neg = torch.log(torch.clamp(-a, min=torch.finfo(a.dtype).tiny))
    lcse_pos = torch.logcumsumexp(b + log_a_pos, dim=dim)
    lcse_neg = torch.logcumsumexp(b + log_a_neg, dim=dim)
    # the negative part carries a phase of pi: -x = exp(log(x) + i*pi)
    return _log_add_exp(lcse_pos, lcse_neg + 1j * np.pi)

def _log_add_exp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # log(exp(x) + exp(y)) for complex x & y, factoring out the larger real part
    xr = torch.real(x)
    xi = torch.imag(x) if torch.is_complex(x) else torch.zeros_like(x)
    yr = torch.real(y)
    yi = torch.imag(y) if torch.is_complex(y) else torch.zeros_like(y)
    x_greater = xr > yr
    rmax = torch.where(x_greater, xr, yr)
    imax = torch.where(x_greater, xi, yi)
    rmin = torch.where(x_greater, yr, xr)
    imin = torch.where(x_greater, yi, xi)
    return rmax + torch.log(torch.exp(1j * imax) + torch.exp(rmin - rmax + 1j * imin))
```
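Assuming the helpers above, a quick sanity check against the naive (overflow-prone) formulation could look like this; comparing `exp()` of both results sidesteps the `2*pi` phase ambiguity in the imaginary part:

```python
z = torch.randn(5, dtype=torch.complex128)
naive = torch.log(torch.cumsum(torch.exp(z), dim=0))
composite = _logcumsumexp(z, dim=0)
# exp() removes the i*2*pi*n ambiguity before comparison
assert torch.allclose(torch.exp(naive), torch.exp(composite), atol=1e-6)
```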
Partial work from #90847, in the direction of solving #89205. Most of the content is from #90847, but this is only for CPU, so hopefully it does not increase the build time by a lot. tag: @albanD, @malfet Pull Request resolved: #93153 Approved by: https://github.com/malfet, https://github.com/Skylion007
Hopefully fixes #89205. This is another version of #90847, which was reverted because it increased the compile time significantly. From my discussion with @ngimel in #93153 (comment), it seems the jiterator option would be very tricky, if not impossible. So what I did was to optimize the compile time on my machine. To do this, I first compiled PyTorch as a whole, then changed only the `LogcumsumexpKernel.cu` file to see how that changed the compile time. Here are my results for the compilation time of only the `LogcumsumexpKernel.cu` file on my machine:

- Original version (without any complex implementations): 56 s (about 1 minute)
- The previous PR (#90847): 13 m 57 s (about 14 minutes)
- This PR: 3 m 35 s (about 3.5 minutes)

If the previous PR increased the build time by 30 minutes on PyTorch's machines, then this PR should reduce the build-time increase to about 6 minutes. Hopefully this is an acceptable level of build-time increase. What I did was (sorted from the most significant build-time reduction; see the sketch below):

- Substituting `log(x)` with `log1p(x - 1)`. This is applied in the infinite case, so we don't really care about precision.
- Implementing the complex exponential manually.

tag: @malfet, @albanD
Pull Request resolved: #94310. Approved by: https://github.com/Skylion007, https://github.com/malfet
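The actual changes live in the C++/CUDA kernel; the following is a hedged Python rendering of the two ideas, for illustration only (function names are hypothetical):

```python
import torch

def complex_exp(z: torch.Tensor) -> torch.Tensor:
    # exp(a + i*b) = exp(a) * (cos(b) + i*sin(b)), written out manually
    # instead of calling the library's complex exp
    a, b = torch.real(z), torch.imag(z)
    return torch.exp(a) * torch.complex(torch.cos(b), torch.sin(b))

def log_via_log1p(x: torch.Tensor) -> torch.Tensor:
    # log(x) for real x rewritten as log1p(x - 1); used only on the infinite
    # branches, where precision near x == 1 does not matter
    return torch.log1p(x - 1)
```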
Another PR towards solving #89205.

What's in this PR:
- `logcumsumexp` for complex numbers in CPU & CUDA
- Backward of `logcumsumexp` for complex numbers
- Tests of `logcumsumexp` for complex numbers

What's missing:
- `gradgradcheck` of `logcumsumexp` (it complains `RuntimeError: logcumsumexp does not support automatic differentiation for outputs with complex dtype.`, and I don't know how to solve the error or where to put the test for the backward computation). If possible, I'd like this to be done in this PR.

It's really tricky to handle the edge cases here (i.e. the ones involving `inf`), but I've tried my best to put comments explaining the reasoning behind my decisions in this PR.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10
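As a concrete instance of those edge cases (for real inputs, where the behavior is already defined): a `-inf` entry contributes `exp(-inf) = 0` to the cumulative sum, so it must pass through without poisoning later entries with `nan`:

```python
import torch

x = torch.tensor([float("-inf"), 0.0, 1.0])
print(torch.logcumsumexp(x, dim=0))
# tensor([  -inf, 0.0000, 1.3133])  -- log(e^0 + e^1) = log1p(e) ≈ 1.3133
```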