
Conversation

@rlucas7 (Contributor) commented Jul 29, 2020

Unblocks implementation of #27036. Note that this PR does not fix #27036.
Currently the QR decomposition backward only supports the square and tall (a.k.a. skinny) cases.
This PR adds support for wide A matrices/tensors, includes 3 unit tests for the new case,
and restructures the qr_backward method to use the same Walther-based calculation as a helper.

cc @albanD @t-vi

I don't have a GPU machine so I haven't tested on CUDA, but everything passes on my local machine on CPU.

The basic idea of the PR is noted in the comments in the Functions.cpp file, but I'll note it here too for clarity:

Let A be an m x n matrix with m < n and partition A as A = (X | Y), where X is the leading m x m block,
and take the QR of X and call that one X = QR.
The Q here from X is the same as the Q from the QR of the entire matrix A. Then transform Y with the rotation got from X, i.e. Y' = Q^T Y, and similarly for the grads of each piece, e.g. if grad_A is the grad of A then grad_A = (grad_X | grad_Y),
with grad_Y = Q @ grad_V, where
grad_V is the narrow() of grad_R (its trailing n - m columns).
grad_X is calculated very similarly to the original Walther formula (exactly the same in the tall and square cases) but is slightly modified here for wide case matrices.
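
To make the partition idea concrete, here is a quick numerical check of the identity, calling torch.qr directly (an illustrative sketch only, not code from this PR):

```python
import torch

m, n = 3, 5
A = torch.randn(m, n, dtype=torch.double)
X, Y = A[:, :m], A[:, m:]                # leading m x m block and trailing m x (n - m) block

Q_full, R_full = torch.qr(A, some=True)  # reduced QR of the wide matrix
Q_x, R_x = torch.qr(X, some=True)        # QR of the leading square block only

# Both factorizations go through the same LAPACK routine on the same leading
# columns, so the Q factors coincide and R = (R_X | Q^T Y). Expect True, True:
print(torch.allclose(Q_full, Q_x))
print(torch.allclose(R_full, torch.cat([R_x, Q_x.t() @ Y], dim=1)))
```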

dr-ci bot commented Jul 29, 2020

💊 CI failures summary and remediations

As of commit 86994fe (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@t-vi (Collaborator) commented Jul 29, 2020

@lucas7 Thank you for working on this! Looks good overall, thank you for throwing in tests, too.

I'll have to look at the maths some more, thank you for the explanation. In the meantime, could you fix the tabs, please?

@albanD (Collaborator) commented Jul 29, 2020

Thanks for the PR!

The diff is not great indeed, maybe the tabs. You can replace them and push again to make the linter happy.

Also is it correct that square_deep_case_backward corresponds exactly to the previous function? Or did you change things in there?

@rlucas7 (Contributor Author) commented Jul 30, 2020

> Thanks for the PR!
>
> The diff is not great indeed, maybe the tabs. You can replace them and push again to make the linter happy.

Yeah that's what I get for cutting a PR so late at night. Will remove tabs on update.

> Also is it correct that square_deep_case_backward corresponds exactly to the previous function? Or did you change things in there?

Yeah, sorry, I wasn't communicating clearly. What I meant with the comment is that when the input matrix is tall or square, the calculation is the same. I left a comment inline in the PR at the spot I was referring to; apologies for the confusion.

@rlucas7 (Contributor Author) commented Jul 30, 2020

> @lucas7 Thank you for working on this! Looks good overall, thank you for throwing in tests, too.
>
> I'll have to look at the maths some more, thank you for the explanation. In the meantime, could you fix the tabs, please?

Oops, wrong alias, I'm rlucas7 but whoever lucas7 is, they've got an awesome avatar :)

Is there a way for me to run the linter locally before I push? I hate to waste the resources to test on the CI with so many builds just for a linter change.

I found this one: https://fossies.org/linux/pytorch/CONTRIBUTING.md#pre-commit-tidylinting-hook but I'm not sure if that instruction is up to date.

@rlucas7 (Contributor Author) commented Jul 30, 2020

I ran `flake8 tools/autograd/templates/Functions.cpp` and resolved all the tabs and other EXXX issues that came up in the lines I touched, ≈LIC 1973-LIC 2090. I also did a `mv .flake8 .otherfilename` and re-ran flake8; the remaining issues on the lines in this file are all E501, but the .flake8 file sets the line-length limit to 120 instead of 80, so those should be fine (I then put .flake8 back where git expects it :) ). There are still some flake8 errors when I run `flake8 tools/autograd/templates/Functions.cpp` locally, but these are on LIC 1 - LIC 31, which I didn't touch. If you'd like me to resolve those style/flake8 issues within this PR I can, but they aren't related to the changes, so I chose to omit them unless instructed otherwise; hope that is ok.

Also, I'm not sure if you prefer me to do a rebase with the changes. I usually wait until all comments are addressed before rebasing and squashing, and this one was still addressing the changes needed to get the builds to green.

@t-vi (Collaborator) commented Jul 30, 2020

@rlucas7 sorry about the username. You want git-clang-format or so for C++, not flake8.
We do squash the PR commits on merging, so that isn't a problem.

@t-vi (Collaborator) commented Jul 30, 2020

Regarding the organisation of the code: I wonder whether we could drop the lambda here; it seems to me that we usually just use something like _qr_backward_square_tall for our helper functions in Functions.cpp (e.g. _euclidean_dist_backward). The functions aren't exported, so I don't see the immediate benefit of having the lambda. I think this would also make it a bit clearer what is old and what is new, as the old function could just continue to live on as the helper.
What do you think?

@albanD (Collaborator) commented Jul 30, 2020

All lint is fine now :)
I would agree that if we can make the diff cleaner it would be great. If you cannot do it, that's fine as well; it will just take a bit longer to review.

Comment on lines 2052 to 2053
rlucas7 (Contributor Author):

From here down is the control flow for the wide case, which is the new stuff.

Comment on lines 2059 to 2088
rlucas7 (Contributor Author):

This comment is using the python @ operator for the (Euclidean) inner product. If this is confusing as written I can replace with the equivalent operator for the Tensor class.
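
For instance, with made-up shapes, the `@` notation above just means torch.matmul:

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 2)
print(torch.equal(a @ b, torch.matmul(a, b)))  # True: `@` dispatches to matmul
```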

Collaborator:

I think this is fine

@rlucas7 (Contributor Author) commented Jul 30, 2020

> Regarding the organisation of the code: I wonder whether we could drop the lambda here; it seems to me that we usually just use something like _qr_backward_square_tall for our helper functions in Functions.cpp (e.g. _euclidean_dist_backward). The functions aren't exported, so I don't see the immediate benefit of having the lambda.

I was also a bit confused about the choice of a _* method versus a lambda. Unfortunately, the Functions.cpp file isn't internally consistent on this (or at least it has some logic I didn't grasp). There are cases of helper methods like you mention and also examples of lambdas, which are used throughout the matrix autodiff methods. For example, det_backward uses 2 lambdas.

One for the singular case:

`auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {`

And one for the non-singular case:

`auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {`

A similar structure is followed for the other matrix backprop methods, e.g. slogdet_backward

`Tensor slogdet_backward(const Tensor& grad_logabsdet,`

and logdet_backward

`Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {`

Plus those are closer in the Functions.cpp file to where qr_backward sits, which is why I ultimately chose the lambda approach.
If function names aren't exported anywhere, then the only benefit would be if the function were used by another function in Functions.cpp, and I didn't see any case of that in Functions.cpp.

> I think this would also make it a bit clearer what is old and what is new, as the old function could just continue to live on as the helper.
> What do you think?

I don't think it will make it clearer what is old and what is new. I initially wrote the function as a separate function to keep it simpler. IIRC the only changes I made when moving it into the lambda were changing some of the matrix overwrites to reduce memory use and handling the cases where grad_Q.defined() and/or grad_R.defined(). Those changes aren't in the lambda itself but in the qr_backward control flow for the new wide case; I've left an inline note where that part of the changeset occurs.

I think the confusion comes from the way GitHub is splitting the lambda refactor into 3 hunks in the diff, obfuscating that this is really a lift and shift into a lambda. For this I also left some inline comments to help clarify what's going on in terms of the changes.

Maybe take a look at that and let me know if it's still unclear and you want to pull it out as a separate function?

If you both still prefer it as a separate _* helper after looking at comments I can change it to be written that way instead.

Overall I think it might help for me to enumerate changes that are made to qr_backward in this PR:

  1. refactor original qr_backward calculation into a lambda
  2. add support for qr_backward when self (or A as I call it in lambda) is wide (more columns than rows)
    Note this change includes some overwrites of Tensors to use less memory in the calculations.
  3. additional handling for when the grad_R and/or grad_Q are (or are not) defined.
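
On point 3, the undefined-grad situation comes up whenever only one of the two outputs feeds into the loss; a rough illustration (not one of the PR's unit tests):

```python
import torch

A = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
Q, R = torch.qr(A)
# Only R participates in the loss, so the incoming gradient for Q is undefined
# and qr_backward has to cope with the missing grad_Q.
R.sum().backward()
print(A.grad.shape)  # torch.Size([4, 4])
```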

Also would it help if I shared some notes on the Maths here for the wide case?

Using LaTeX in GitHub comments is pretty clunky, but I could write up the algebra for the wide case in LaTeX and share it via email. Some of the other SciPy devs and I have done that before on some random variate generators and found it helpful for understanding why things were coded the way they were.

@albanD self-requested a review July 31, 2020 21:03
@albanD added the "triaged" label (This issue has been looked at a team member, and triaged and prioritized into an appropriate module) Jul 31, 2020
@mruberry added the "module: linear algebra" label (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul) Jul 31, 2020
@rlucas7 (Contributor Author) commented Aug 16, 2020

Hello, friendly inquiry: is anything further needed from my end at this time? Any updates?

@albanD (Collaborator) left a comment:

Hey,
Sorry for the delay.
The code change looks good to me. I added just small comments for formatting and comments.

But looking at the test, I think the old specification was not correct and not checking the gradients. It would be nice if you could re-activate that for all the settings and make sure the formula works fine!
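
Something along these lines, I assume (a sketch of what the re-activated check could look like; gradcheck needs double precision, and the wide input exercises the new formula):

```python
import torch
from torch.autograd import gradcheck

A = torch.randn(3, 5, dtype=torch.double, requires_grad=True)  # wide input, the new case
# Compares the analytical qr_backward against numerically estimated gradients.
print(gradcheck(lambda a: torch.qr(a, some=True), (A,)))
```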

Collaborator:

What does this comment refer to? Is it just an artefact of an older refactor?

rlucas7 (Contributor Author):

Artefact; I agree the comment is confusing. I'll remove the line.

Collaborator:

I think the trick with caps to avoid changing the code is ok.
But to avoid any issue with unwanted aliasing, I would prevent any capture from the lambda by just setting [] (or if you need anything, specify it explicitly).

rlucas7 (Contributor Author):

> I think the trick with caps to avoid changing the code is ok.
> But to avoid any issue with unwanted aliasing, I would prevent any capture from the lambda by just setting [] (or if you need anything, specify it explicitly).

Got it. I'll remove the captures ampersand, thanks for catching that one.

Collaborator:

nit: you could extract m and n before this to make it slightly more readable.

rlucas7 (Contributor Author):

yes, good idea, that will make it clearer

rlucas7 (Contributor Author):

yes, makes it clearer, thanks for that

Collaborator:

I think you're missing a sentence that says how to compute grad_y?
Or just remove it as it is mentioned below and move the final grad_A formula 3 lines down?

rlucas7 (Contributor Author):

yes, I'll remove the grad_Y part from the line and make the comment a bit clearer, and shuffle around grad_A formula as you mention

Comment on lines 2059 to 2088
Collaborator:

I think this is fine

Collaborator:

Actually this does not.
foo = bar does not write into foo in place. It just frees the old content of foo and puts a copy (talking about a shared_ptr-like object here, so no content copy) of bar into it.
Can you just update the comment?

In general, I would advise against doing inplace changes here as it would most likely prevent double backward from working.
In this case, grad_Y is actually a view of the provided grads that should never be modified inplace.
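
A small illustration of the distinction, with made-up tensors (Python-level, but the same reasoning applies to the C++ Tensor handles):

```python
import torch

grad_R = torch.randn(3, 5)

v = grad_R.narrow(1, 3, 2)   # a view into grad_R, not a copy
v = torch.zeros(3, 2)        # rebinding: grad_R is untouched, `v` just names a new tensor

w = grad_R.narrow(1, 3, 2)
w.zero_()                    # in-place op on the view: silently zeroes those columns of grad_R
print(grad_R[:, 3:].abs().sum())  # tensor(0.)
```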

rlucas7 (Contributor Author):

> Actually this does not.
> foo = bar does not write into foo in place. It just frees the old content of foo and puts a copy (talking about a shared_ptr-like object here, so no content copy) of bar into it.
> Can you just update the comment?

I will remove the line from the comments.

I think we also want to remove the overwrite, correct?

Thanks for catching that, is this what you're referring to?

```cpp
Tensor output = has_out ? at::_unsafe_view(at::bmm_out(out, tensor1_expanded, tensor2_expanded), output_shape)
                        : at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
```

IIUC, I will add a grad_V to the control flow and remove the overwriting, so LIC 2087 will read:
`grad_Y = at::matmul(q, grad_V);`
(changing the grad_Y entries above to read grad_V).

> In general, I would advise against doing inplace changes here as it would most likely prevent double backward from working.
> In this case, grad_Y is actually a view of the provided grads that should never be modified inplace.

Collaborator:

Yes that change looks good.

> is this what you're referring to?

In some sense yes. You can use the version where you provide the output but that does not work with autograd!

@rlucas7 (Contributor Author) commented Aug 21, 2020

> Hey,
> Sorry for the delay.

No worries @albanD thanks for your time and thorough review.

> The code change looks good to me. I added just small comments for formatting and comments.
>
> But looking at the test, I think the old specification was not correct and not checking the gradients. It would be nice if you could re-activate that for all the settings and make sure the formula works fine!

I wasn't clear on this one after doing some repo spelunking; I left a couple of links at the inline comment about what I found. It doesn't seem that the 5th tuple is used?

The unit tests do seem to pass both with and without them, though.

I'll put together the full changeset to address comments and update the PR over the weekend.

@nikitaved (Collaborator) commented Aug 22, 2020

@rlucas7, just so you know, there is an issue with the current qr_backward implementation, described in #42792, which materializes when the input matrix is not of full rank.
In simple words, the referenced method assumes R to be of full rank (R is square in the paper), so that R^{-1} exists. R^{-1} is not being computed explicitly, of course, but the system involving R might be inconsistent.

Do you guarantee that the matrix X_{n,n} is of full rank (if rank(A) = n)? Because if so, then the backward formula from the paper as it is currently implemented should work with no problems, I think.

@nikitaved (Collaborator) commented Aug 22, 2020

@rlucas7, @albanD, the way X is defined with the assumption of being full rank screams for a rank-revealing QR. Maybe it makes sense to completely abolish the current QR implementation and substitute it with a version with pivoting. Otherwise it is quite hard to implement the backward correctly without knowing which columns of the input are actually linearly independent.

But yeah, there is an issue, because users will probably just want a simple A = QR instead of having an additional permutation matrix so that AP = QR.
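
For reference, this is what the pivoted (rank-revealing) factorization being discussed looks like, e.g. via SciPy's geqp3-backed routine (PyTorch did not expose a pivoted QR at the time):

```python
import numpy as np
from scipy.linalg import qr

A = np.random.randn(3, 5)
Q, R, p = qr(A, pivoting=True)       # column-pivoted QR: A[:, p] = Q @ R, i.e. A P = Q R
print(np.allclose(A[:, p], Q @ R))   # True
print(p)                             # pivot order; |diag(R)| is non-increasing
```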

@albanD (Collaborator) commented Aug 24, 2020

@nikitaved what about we add a warning to the doc specifying the current limitation of the backward for qr and land this improvement (which has the same limitations)?
And in a future PR, when we update the QR forward, we can update the formula here to do the right thing?

@nikitaved (Collaborator):

@albanD , sure, makes sense. I can actually update the documentation.

@rlucas7 (Contributor Author) commented Aug 26, 2020

@albanD wrote:

> @nikitaved what about we add a warning to the doc specifying the current limitation of the backward for qr and land this improvement (which has the same limitations)?
> And in a future PR, when we update the QR forward, we can update the formula here to do the right thing?

@nikitaved wrote:

> @albanD , sure, makes sense. I can actually update the documentation.

It seems like the tril solver issue and the dgeqp3 request are separate issues.
I agree the suggestion of @albanD seems like a reasonable path forward.

AFAICS the only change here with a RRQR (on forward pass) would be a right multiply by a permutation matrix P before doing gradient calculations.
Then, after the gradients are calculated, post-multiply the returned grad_A by P^T, the inverse of P, to move the columns back to their original arrangement.
This would be a very small change to the existing code (once dgeqp3 is brought into PyTorch for the forward pass).

The use of RRQR would not change the fact that the equation requires rank(A)=k=min(m,n).
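
In code, the column-permutation bookkeeping sketched above would look roughly like this (made-up names, not part of this PR):

```python
import torch

m, n = 3, 5
A = torch.randn(m, n)
perm = torch.randperm(n)         # stand-in for the pivoting a dgeqp3-based forward would pick
AP = A[:, perm]                  # the forward would factor the column-permuted matrix: A P = Q R

grad_AP = torch.randn(m, n)      # whatever the existing wide-case formula produces for A P
grad_A = torch.empty_like(grad_AP)
grad_A[:, perm] = grad_AP        # right-multiply by P^T: move columns back to original order
print(grad_A.shape)
```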

@albanD it looks like there are some conflicts on this branch, causing test failures in CI.
It's difficult for me to determine: is a `git rebase -i master` required here to get the CI to green, or is something more required?

I make this guess from a ctrl-F for "error" in this log:

https://dr.pytorch.org/api/view-log-full?build_id=129136002

which seems to indicate merge conflicts; the other Windows builds seem to have network connection issues (unrelated, I assume).

@albanD (Collaborator) commented Aug 26, 2020

Hi,

@nikitaved already sent a PR for the doc so that's good.

Yes, can you rebase on top of master to make CI happy please?

@albanD (Collaborator) left a comment:

LGTM
Will wait on the updated CI signal before merging.

clean up linter errors

remove overwrites and cleanup comments
codecov bot commented Aug 27, 2020

Codecov Report

Merging #42216 into master will increase coverage by 0.00%.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master   #42216   +/-   ##
=======================================
  Coverage   69.40%   69.40%           
=======================================
  Files         378      378           
  Lines       46610    46610           
=======================================
+ Hits        32350    32351    +1     
+ Misses      14260    14259    -1     
| Impacted Files | Coverage Δ |
|---|---|
| ...ch/testing/_internal/common_methods_invocations.py | 91.12% <ø> (ø) |
| torch/testing/_internal/expecttest.py | 78.57% <0.00%> (+1.02%) ⬆️ |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 0bf27d6...86994fe.

@facebook-github-bot (Contributor) left a comment:

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):
@albanD merged this pull request in 2bede78.

@dvirginz commented Sep 4, 2020

@rlucas7 Thank you very much for the contribution!
Just to make sure, should this make 'lstsq' differentiable?
I'm using the nightly version of torch (1.7.0.dev20200902+cu101), and I'm still getting the "the derivative for 'lstsq' is not implemented" error.
Thanks!

@rlucas7 (Contributor Author) commented Sep 4, 2020 via email


Labels

- Merged
- module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
- open source
- triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
