
Conversation

@mfkasim1
Contributor

@mfkasim1 mfkasim1 commented Dec 14, 2022

Another PR towards solving #89205.
What's in this PR:

  • The implementation of forward logcumsumexp for complex numbers in CPU & CUDA
  • The tests on forward call of logcumsumexp for complex numbers
  • The implementation of backward logcumsumexp for complex numbers

What's missing:

  • The test of the backward gradient of logcumsumexp (it complains RuntimeError: logcumsumexp does not support automatic differentiation for outputs with complex dtype., and I don't know how to solve the error or where to put the test for the backward computation). If possible, I'd like this to be done in this PR.

It's really tricky to handle the edge cases here (i.e. the ones involving inf), but I've tried my best to put comments in this PR explaining the reasoning behind my decisions.
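The kind of inf edge case in question can be illustrated with a minimal sketch (my own example, not code from this PR): the standard max-shift trick for log-add-exp breaks down when both inputs are -inf.

```python
import math

def logaddexp_naive(x: float, y: float) -> float:
    # Standard max-shift trick: log(e^x + e^y) = m + log(e^(x-m) + e^(y-m)).
    # Stable for finite inputs, but when x = y = -inf the shift x - m
    # evaluates -inf - (-inf) = nan, and the nan propagates to the result
    # (the mathematically correct answer there would be -inf).
    m = max(x, y)
    return m + math.log(math.exp(x - m) + math.exp(y - m))
```

A correct implementation has to special-case combinations like this, which is where most of the per-branch reasoning in the kernel comments comes from.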

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot

pytorch-bot bot commented Dec 14, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90847

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 Failures

As of commit d28c203:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 14, 2022
@soulitzer soulitzer removed their request for review December 14, 2022 18:21
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Dec 15, 2022
@mfkasim1
Contributor Author

@albanD could you please help me enable this function to support automatic differentiation for complex dtypes?
I keep getting the error RuntimeError: logcumsumexp does not support automatic differentiation for outputs with complex dtype. I've searched the files containing the word logcumsumexp but couldn't find the one I need to change.

@albanD
Collaborator

albanD commented Dec 26, 2022

Sure!
Just add it to

GRADIENT_IMPLEMENTED_FOR_COMPLEX = {

This is still an allow-list because most ops are not supported, and we want to be careful when new ops are added: it is easy to forget to test the complex case.
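For context, the suggested change is just adding the op's name to that allow-list (a Python set in the autograd codegen). A sketch of the edit; the entries other than `logcumsumexp` are illustrative, the real set contains many more:

```python
# Sketch of the allow-list edit in the autograd codegen (illustrative
# neighboring entries; only "logcumsumexp" is the actual addition here)
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
    "exp",
    "log",
    # ... existing entries ...
    "logcumsumexp",  # new: enables complex autograd for this op
}
```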

@albanD
Collaborator

albanD commented Jan 4, 2023

The errors seem to be precision issues with the gradients that you compute, right?

@mfkasim1
Contributor Author

mfkasim1 commented Jan 5, 2023

@albanD yes, but it seems there are other, unrelated errors that I don't know how to get rid of.
On the gradient side, it seems I was missing .conj() in various places.

Collaborator

@albanD left a comment

Looks quite good! Only small comments.

Collaborator

You might want to define this as a ref for the OpInfo; it will then check that the refs match for all simple inputs without the need for this custom code. Or are there more cases tested here that cannot be tested via the OpInfo?

Contributor Author

I found several challenges in using OpInfo (I probably don't understand OpInfo well enough to fully utilize it) for this complex logcumsumexp function:

  1. To compare two outputs of logcumsumexp, the imaginary part needs to be standardized to lie within [0, 2π) or (−π, π]. This is because log(r e^{it}) = log|r| + i(t + 2πn).
  2. scipy.special.logsumexp gives some confusing answers in cases involving inf (see my comment around line 590 of this file)
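Point 1 can be sketched as follows, with a hypothetical helper (not part of the PR) that wraps the imaginary part into [−π, π) so that two complex logs differing by 2πni compare equal:

```python
import math

import torch

def standardize_imag(z: torch.Tensor) -> torch.Tensor:
    # Wrap the imaginary part into [-pi, pi) so that values differing by
    # i * 2 * pi * n become identical before comparison.
    im = torch.remainder(torch.imag(z) + math.pi, 2 * math.pi) - math.pi
    return torch.complex(torch.real(z), im)
```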

Collaborator

@albanD left a comment

Thanks for the update. Sounds good!

@mfkasim1
Contributor Author

Thanks for the approval, @albanD! However, the last Windows CI run still produced a test error. I can't reproduce it on my machine; somehow it is only raised on the Windows instance (not the others). It's related to an edge case involving inf and nan. If possible, I would like to resolve this before merging.

@ngimel
Collaborator

ngimel commented Jan 11, 2023

Math libraries are often different on Windows and produce different results in edge cases. I think you can just skip the test on Windows (we already have quite a few such skips).

@mfkasim1
Contributor Author

Thanks, @ngimel! I'll push an update that skips the test on Windows.

@mfkasim1
Contributor Author

I think I'm done with the inf/nan bug on Windows. If any of you would like to comment on anything in this PR, please let me know.

Collaborator

@albanD left a comment

Sounds good to me

@mfkasim1
Contributor Author

What's next for this?

@albanD
Collaborator

albanD commented Jan 17, 2023

Oh sorry, once the PR is approved you can ask the bot to merge it yourself :)

@pytorchbot -h

@pytorch-bot

pytorch-bot bot commented Jan 17, 2023

PyTorchBot Help

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

In order to invoke the bot on your PR, include a line that starts with
@pytorchbot anywhere in a comment. That line will form the command; no
multi-line commands are allowed. 

Example:
    Some extra context, blah blah, wow this PR looks awesome

    @pytorchbot merge

optional arguments:
  -h, --help            Show this help message and exit.

command:
  {merge,revert,rebase,label,drci}
    merge               Merge a PR
    revert              Revert a PR
    rebase              Rebase a PR
    label               Add label to a PR
    drci                Update Dr. CI

Merge

usage: @pytorchbot merge [-g | -f MESSAGE | -l] [-r [{viable/strict,master}]]

Merge an accepted PR, subject to the rules in .github/merge_rules.json.
By default, this will wait for all required checks (lint, pull) to succeed before merging.

optional arguments:
  -g, --green           Merge when all status checks running on the PR pass. To add status checks, use labels like `ciflow/trunk`.
  -f MESSAGE, --force MESSAGE
                        Merge without checking anything. This requires a reason for auditing purposes, for example:
                        @pytorchbot merge -f 'Minor update to fix lint. Expecting all PR tests to pass'
  -l, --land-checks     [Deprecated - your PR instead now gets the `ciflow/trunk` label on approval] Merge with land time checks. This will create a new branch with your changes rebased on viable/strict and run a majority of trunk tests _before_ landing to increase trunk reliability and decrease risk of revert. The tests added are: pull, Lint and trunk. Note that periodic is excluded.
  -r [{viable/strict,master}], --rebase [{viable/strict,master}]
                        Rebase the PR to re run checks before merging.  Accepts viable/strict or master as branch options and will default to viable/strict if not specified.

Revert

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Revert a merged PR. This requires that you are a Meta employee.

Example:
  @pytorchbot revert -m="This is breaking tests on trunk. hud.pytorch.org/" -c=nosignal

optional arguments:
  -m MESSAGE, --message MESSAGE
                        The reason you are reverting, will be put in the commit message. Must be longer than 3 words.
  -c {nosignal,ignoredsignal,landrace,weird,ghfirst}, --classification {nosignal,ignoredsignal,landrace,weird,ghfirst}
                        A machine-friendly classification of the revert reason.

Rebase

usage: @pytorchbot rebase [-s | -b BRANCH]

Rebase a PR. Rebasing defaults to the stable viable/strict branch of pytorch.
You must have write permissions to the repo to rebase a PR.

optional arguments:
  -s, --stable          [DEPRECATED] Rebase onto viable/strict
  -b BRANCH, --branch BRANCH
                        Branch you would like to rebase to

Label

usage: @pytorchbot label labels [labels ...]

Adds label to a PR

positional arguments:
  labels  Labels to add to given Pull Request

Dr CI

usage: @pytorchbot drci

Update Dr. CI. Updates the Dr. CI comment on the PR in case it's gotten out of sync with actual CI results.

@mfkasim1
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 18, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR:
@pytorchbot rebase

Details for Dev Infra team Raised by workflow job

@mfkasim1
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased clse1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout clse1 && git pull --rebase)

@mfkasim1
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.


@mfkasim1
Contributor Author

@pytorchbot merge -f "unrelated error"

@pytorch-bot

pytorch-bot bot commented Jan 20, 2023

You are not authorized to force merges to this repository. Please use the regular @pytorchmergebot merge command instead

@mfkasim1
Contributor Author

@albanD I think I need your help in merging this PR. The error seems to be a continuation of #92626

@albanD
Collaborator

albanD commented Jan 20, 2023

@pytorchbot merge -f "Flaky CI"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@malfet
Contributor

malfet commented Jan 24, 2023

This PR increased nightly build time by 30 min; trying a few remedies.

@vors
Contributor

vors commented Jan 24, 2023

We also observed this code timing out our builds.

@ngimel
Collaborator

ngimel commented Jan 24, 2023

@mfkasim1 is the performance of this operation important? Such an increase in build time for a niche op is not great.

@malfet
Contributor

malfet commented Jan 24, 2023

@pytorchbot revert -m "Reverting to decrease build time, let's discuss the alternatives here" -c weird

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@mfkasim1 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jan 24, 2023
This reverts commit 6498512.

Reverted #90847 on behalf of https://github.com/malfet due to Reverting to decrease build time, let's discuss the alternatives here
@mfkasim1
Contributor Author

@mfkasim1 is the performance of this operation important? Such an increase in build time for a niche op is not great.

Yes, this is the bottleneck operation in my research.
What could I do to make the build time shorter?

@malfet
Contributor

malfet commented Jan 24, 2023

Yes, this is the bottleneck operation in my research. What could I do to make the build time shorter?

Can you give some quantifiable numbers? I.e., how bad would the perf be if it were a composite op? Also, can we make it jiteratable for just complex numbers? Or keep it out of core?

@mfkasim1
Contributor Author

mfkasim1 commented Jan 24, 2023

I tried implementing it with the real logcumsumexp; the complex version needs 4 calls to the real logcumsumexp plus many other ops (see the code below).
I'm not familiar with the jiterator. Could you give me some pointers on it? I think I initially tried to use the jiterator for this, but I encountered a lot of errors during the build, so I gave up.

import math

import torch

def _logcumsumexp(z: torch.Tensor, dim: int) -> torch.Tensor:
    # logcumsumexp of complex z = q + i*k, built from the real op:
    # cumsum(exp(z)) = cumsum(cos(k) * exp(q)) + i * cumsum(sin(k) * exp(q))
    q = torch.real(z)
    k = torch.imag(z)
    a = _logcumsum_aexp(torch.cos(k), q, dim=dim)
    b = _logcumsum_aexp(torch.sin(k), q, dim=dim)
    # adding i*pi/2 in log-space multiplies the second term by i
    return _log_add_exp(a, b + 0.5j * math.pi)

def _logcumsum_aexp(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    # log(cumsum(a * exp(b))) for real a & b; the result is complex because a
    # can be negative. Split a into positive and negative parts, run the real
    # logcumsumexp on each, then recombine (adding i*pi in log-space negates
    # the negative part's contribution).
    tiny = torch.finfo(a.dtype).tiny
    log_a_pos = torch.log(torch.clamp(a, min=tiny))
    log_a_neg = torch.log(torch.clamp(-a, min=tiny))
    lcse_pos = torch.logcumsumexp(b + log_a_pos, dim=dim)
    lcse_neg = torch.logcumsumexp(b + log_a_neg, dim=dim)
    return _log_add_exp(lcse_pos, lcse_neg + 1j * math.pi)

def _log_add_exp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Elementwise log(exp(x) + exp(y)) for complex (or real) x & y,
    # shifted by the larger real part for stability.
    xr = torch.real(x)
    xi = torch.imag(x) if torch.is_complex(x) else torch.zeros_like(x)
    yr = torch.real(y)
    yi = torch.imag(y) if torch.is_complex(y) else torch.zeros_like(y)
    x_greater = xr > yr
    rmax = torch.where(x_greater, xr, yr)
    imax = torch.where(x_greater, xi, yi)
    rmin = torch.where(x_greater, yr, xr)
    imin = torch.where(x_greater, yi, xi)
    return rmax + torch.log(torch.exp(1j * imax) + torch.exp(rmin - rmax + 1j * imin))
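For checking composites like the above on moderate inputs, a simple max-shifted reference can be used (my own sketch, not part of the PR); it is only valid when the shifted exponentials don't overflow:

```python
import torch

def naive_logcumsumexp(z: torch.Tensor, dim: int) -> torch.Tensor:
    # Reference implementation: shift by the max real part along `dim`
    # (a positive real factor, so it does not change the complex log's
    # branch), cumsum in complex arithmetic, then shift back.
    m = torch.amax(torch.real(z), dim=dim, keepdim=True)
    return m + torch.log(torch.cumsum(torch.exp(z - m), dim=dim))
```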

@mfkasim1 mfkasim1 mentioned this pull request Jan 27, 2023
pytorchmergebot pushed a commit that referenced this pull request Jan 27, 2023
Partial work from #90847, in the direction of solving #89205.
Most of the content is from #90847, but this is only for CPU, so hopefully it does not increase the build time by a lot.

tag: @albanD, @malfet

Pull Request resolved: #93153
Approved by: https://github.com/malfet, https://github.com/Skylion007
pytorchmergebot pushed a commit that referenced this pull request Feb 13, 2023
Hopefully fixes #89205.
This is another version of #90847, which was reverted because it increased the compile time significantly.
From my discussion with @ngimel in #93153 (comment), it seems the jiterator option would be very tricky, if not impossible.
So what I did was optimize the compile time on my machine.

To measure the build time, I first compiled PyTorch as a whole, then changed only the `LogcumsumexpKernel.cu` file to see how it affects the compile time.
Here are the compilation times of just the `LogcumsumexpKernel.cu` file on my machine:

- Original version (without any complex implementations): 56s (about 1 minute)
- The previous PR (#90847): 13m 57s (about 14 minutes)
- This PR: 3m 35s (about 3.5 minutes)

If the previous PR increased the build time by 30 mins on PyTorch's machines, then this PR should reduce the increase to about 6 mins. Hopefully this is an acceptable level of build-time increase.

What I did (sorted by how much it reduces the build time, most significant first):

- Substituting `log(x)` with `log1p(x - 1)`. This is applied in the infinite cases, so we don't really care about precision there.
- Implementing the complex exponential manually
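Both items can be illustrated in a few lines (Python here purely to show the identities; the actual changes are in the CUDA kernel, and the function names below are my own):

```python
import cmath
import math

def exp_manual(a: float, b: float) -> complex:
    # The identity behind "implementing the complex exponential manually":
    # exp(a + bi) = e^a * (cos b + i sin b)
    m = math.exp(a)
    return complex(m * math.cos(b), m * math.sin(b))

def log_via_log1p(x: float) -> float:
    # The log(x) -> log1p(x - 1) substitution; mathematically identical,
    # and any precision difference doesn't matter in the infinite-case
    # branches where the PR applies it.
    return math.log1p(x - 1.0)
```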

tag: @malfet, @albanD
Pull Request resolved: #94310
Approved by: https://github.com/Skylion007, https://github.com/malfet
@lezcano lezcano changed the title Logcumsumexp for complex in CPU and CUDA Add CPU/CUDA support to torch.logcumsumexp Feb 20, 2023

Labels

ciflow/trunk, Merged, module: cpu, open source, Reverted, triaged


9 participants