Add a transform for positive-definite matrices. #76777
Dr. CI (bot): ✅ No failures (0 pending) as of commit c61fde6e3f (more details on the Dr. CI page). Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
lezcano left a comment:
The implementation makes sense to me. I'll let others review whether this is a reasonable PR in the context of distributions (I am not involved in its maintenance).
neerajprad: LGTM, thanks for adding this! cc @nonconvexopt - could you take a look as well?

nonconvexopt: @neerajprad Thank you for tagging me. It is good to know how PyTorch is improving.
Force-pushed: 59b363f to 26e98be
tillahoffmann: I think this one is ready to be merged. Would you be able to take another look, @neerajprad? Thank you for taking the time to review my distribution-related PRs.
Force-pushed: 26e98be to efb3359
Force-pushed: efb3359 to c61fde6
/easycla
As part of the transition to the PyTorch Foundation, this project now requires contributions to be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.
tillahoffmann: Hi @neerajprad, I've now signed the CLA. Would you mind taking another look?
@pytorchbot rebase
lezcano left a comment:
The linalg part looks good. I'll let the people from distributions review the rest.
@pytorchbot successfully started a rebase job. Check the current status here.
Successfully rebased
Force-pushed: c61fde6 to 87023fa
Dr. CI (bot): 🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/76777. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 378a0cd. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@fritzo, @neerajprad PTAL
fritzo left a comment:
Thanks @tillahoffmann, everything looks good except my one comment on inheritance. After that's fixed, LGTM!
tillahoffmann: Thanks for the review, @fritzo. I couldn't find the comment on inheritance, unfortunately. Which line was that on?
fritzo commented on torch/distributions/transforms.py (Outdated):
Please avoid inheriting from LowerCholeskyTransform here. Instead either own an instance or use the global instance. I think something like this should work:
```python
class PositiveDefiniteTransform(Transform):
    def __init__(self, cache_size=0):
        super().__init__(cache_size=cache_size)
        self._lower_cholesky = LowerCholeskyTransform(cache_size=cache_size)

    def _call(self, x):
        x = self._lower_cholesky(x)
        return x @ x.mT

    def _inverse(self, y):
        y = torch.linalg.cholesky(y)
        return self._lower_cholesky.inv(y)
```
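For reference, a minimal usage sketch of the suggested class (assuming it is importable as `torch.distributions.transforms.PositiveDefiniteTransform`, which is what this PR adds). Note that the inverse recovers only the lower-triangular part of the input, because `LowerCholeskyTransform` ignores the strict upper triangle:

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import PositiveDefiniteTransform

t = PositiveDefiniteTransform()
x = torch.randn(3, 3)  # unconstrained square matrix
y = t(x)               # symmetric positive-definite matrix
assert constraints.positive_definite.check(y)

# Round trip: cholesky(y) recovers the lower Cholesky factor, whose
# inverse under LowerCholeskyTransform is x.tril(); the strict upper
# triangle of x was never used, so it cannot be recovered.
assert torch.allclose(t.inv(y), x.tril(), atol=1e-5)
```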
Reply: Just for my understanding, why is this preferred over the current approach?
fritzo replied:
In practice, downstream users may dispatch on transforms via singledispatch, multipledispatch, or isinstance, which could yield incorrect behavior due to the false match. This could also lead to a missed typechecking error, where a LowerCholeskyTransform is required, a PositiveDefiniteTransform is provided, and the typechecker erroneously thinks that's fine.
In theory, the two classes do not satisfy a subtyping relation, an is-a relationship: PositiveDefiniteTransform is not a special case of LowerCholeskyTransform.
(sorry I had tried to explain in my earlier review, but apparently forgot to save the comment)
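To make the hazard concrete, here is a minimal sketch (hypothetical downstream code, not part of this PR) of an `isinstance`-based dispatcher that would take the wrong branch under inheritance:

```python
from torch.distributions.transforms import (
    LowerCholeskyTransform,
    PositiveDefiniteTransform,
)

def codomain_description(t):
    # Hypothetical downstream dispatch on transform type.
    if isinstance(t, LowerCholeskyTransform):
        return "lower-triangular matrix with positive diagonal"
    return "unknown transform"

# With composition, this prints "unknown transform", as it should. Had
# PositiveDefiniteTransform subclassed LowerCholeskyTransform, the
# isinstance check would match and return a description that is wrong
# for positive-definite outputs.
print(codomain_description(PositiveDefiniteTransform()))
```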
tillahoffmann replied:
I've removed the inheritance. Overriding __init__ also required overriding with_cache (see below), and I decided to instantiate a new LowerCholeskyTransform for each transform. The overhead seems minimal, but happy to make adjustments as you see fit. I couldn't find the global instance of LowerCholeskyTransform.
pytorch/torch/distributions/transforms.py, lines 135 to 137 in 777ac63:

```python
if type(self).__init__ is Transform.__init__:
    return type(self)(cache_size=cache_size)
raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))
```
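For reference, a sketch of the `with_cache` override that a custom `__init__` forces (an assumption about what such an override would look like, not necessarily the PR's final code), given that the base-class fallback quoted above only handles classes that reuse `Transform.__init__`:

```python
class PositiveDefiniteTransform(Transform):
    # ... __init__, _call, and _inverse as in the suggestion above ...

    def with_cache(self, cache_size=1):
        # Return self when the cache size already matches; otherwise
        # build a fresh instance with the requested cache size.
        if self._cache_size == cache_size:
            return self
        return PositiveDefiniteTransform(cache_size=cache_size)
```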
Force-pushed: c91d458 to 7e8d1df
fritzo left a comment:
LGTM, thanks for resolving the inheritance issue!
fritzo: @lezcano would you merge this please? (I believe I lost merge privileges.)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: the following mandatory check(s) failed (Rule …). Dig deeper by viewing the failures on hud. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Successfully rebased
Force-pushed: 7e8d1df to 378a0cd
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The `PositiveDefiniteTransform` is required to transform from an unconstrained space to positive definite matrices, e.g. to support testing the Wishart mode in pytorch#76690. It is a simple extension of the `LowerCholeskyTransform`.

I've also added a small test that ensures the generated data belong to the domain of the associated transform. Previously, the data generated for the inverse transform of the `LowerCholeskyTransform` wasn't part of the domain, and the test only passed because the comparison uses `equal_nan=True`.

Pull Request resolved: pytorch#76777
Approved by: https://github.com/lezcano, https://github.com/fritzo, https://github.com/soumith
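A rough sketch of the kind of domain check described above (hypothetical code, not the PR's exact test; the constraint objects are assumptions about the transform's domain and codomain):

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import PositiveDefiniteTransform

t = PositiveDefiniteTransform()
x = torch.randn(4, 4)  # unconstrained square matrix

# Generated data should lie in the transform's domain before the
# inverse is exercised; previously, out-of-domain data produced NaNs
# that were masked by equal_nan=True comparisons.
assert constraints.independent(constraints.real, 2).check(x).all()

y = t(x)
assert constraints.positive_definite.check(y)  # output is in the codomain
```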