fix pca_lowrank memory consumption #40853
Conversation
💊 CI failures summary and remediations
As of commit 7184385 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
pearu left a comment:
LGTM, thanks for this simplification!
pearu left a comment:
Suggestion: use stride tricks to avoid creating possibly large temporary tensors.
```python
C = A.mean(dim=(-2,), keepdim=True)
return _svd_lowrank(A - C, q, niter=niter, M=None)
```
Could you change this code to:
```python
C = A.mean(dim=(-2,), keepdim=True)
M_strides = [C.stride(i) for i in range(len(C.shape))]
M_strides[-2] = 0
M = C.as_strided(A.shape, M_strides)
return _svd_lowrank(A, q, niter=niter, M=M)
```
This will avoid creating a possibly large temporary A - C.
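For illustration, a minimal sketch of the zero-stride view this suggestion builds (shapes and variable names are made up, not from the PR):
```python
import torch

A = torch.randn(10000, 64)
C = A.mean(dim=(-2,), keepdim=True)        # shape (1, 64)

# Zero-stride view of C with A's shape: every logical row aliases C's
# single row, so no (10000, 64) buffer is allocated for M.
M_strides = list(C.stride())
M_strides[-2] = 0
M = C.as_strided(A.shape, M_strides)

print(M.shape)                             # torch.Size([10000, 64])
print(M.data_ptr() == C.data_ptr())        # True: same underlying storage
```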
Will do that, but I suspect it will just postpone the inevitable - when matmul sees discontiguous arguments, it blows them up. Will have to double-check the matmul code, but I vaguely remember it from when I looked at it previously.
Do you ever backprop through pca_lowrank? If not, and A is otherwise not needed, A - C can be computed in place, avoiding a large memory allocation.
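A minimal sketch of that in-place variant, under the same assumptions (A is not needed afterwards and autograd is not involved; shapes are arbitrary):
```python
import torch

A = torch.randn(10000, 64)
C = A.mean(dim=-2, keepdim=True)   # (1, 64) column means

# Broadcasted in-place subtraction: A is overwritten row by row,
# so no separate (10000, 64) tensor is allocated for A - C.
A.sub_(C)
```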
> Will do that, but I suspect it will just postpone the inevitable - when matmul sees discontiguous arguments, it blows them up

I tested the code locally; all pca tests passed OK.
Yeah, sorry, by "blow up" I meant it will materialize expanded contiguous tensors. So yeah, the tests will pass, but memory usage is possibly still suboptimal.
The memory of a is contiguous, but the tensor a is not contiguous. In a contiguous tensor, no two logically different elements point to the same memory location, whereas in a all elements point to the same memory location.
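A small illustration of that distinction (hypothetical tensors, not the ones from the PR):
```python
import torch

c = torch.tensor(1.0)        # one element; its memory is trivially contiguous
a = c.expand(4, 5)           # 20 logical elements, all aliasing that one element

print(c.is_contiguous())     # True
print(a.is_contiguous())     # False: strides are (0, 0), so logically
                             # different elements share one memory location
print(a.stride())            # (0, 0)
```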
Right, the whole point of the expand-to-contiguous conversion is to ensure that the tensor data is C-contiguous and can be passed to some low-level routine that is not stride-aware, that is, a routine that accesses the data memory according to the shape information rather than the strides. That said, matmul low-level routines ought to be stride-aware (thinking of LAPACK), so the expand-to-contiguous conversion would be unnecessary. So there could be some ideas for matmul optimization here.
I hear what you are saying; however, matmul (let's talk about 2d tensors for simplicity) will typically be dispatched to a BLAS library that implements the gemm API. As you can see here, lda/ldb, which provide the stride-awareness, are explicitly constrained to be positive, not 0. That said, instead of dispatching calls with stride=0 arguments to BLAS, doing a pointwise multiplication of the arguments followed by a reduction and expanding back to the necessary size is likely to be faster. Double points if the multiplication could be done on the fly during the reduction, so that the pointwise product A*B did not have to be materialized, but unfortunately PyTorch cannot do that; you'd need KeOps for that.
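A rough sketch of what that alternative could look like for a zero-stride operand like the expanded mean (shapes and names are made up for illustration, not code from the PR):
```python
import torch

m, n, q = 10000, 512, 6
C = torch.randn(1, n)              # e.g. the row of column means
M = C.expand(m, n)                 # zero-stride matrix of identical rows
Q = torch.randn(n, q)

# BLAS gemm cannot take a zero leading dimension, so this path typically
# materializes a contiguous (m, n) copy of M before the multiply.
out_gemm = M @ Q

# Pointwise multiply + reduction on the small operand, then expand back:
# the only new allocations are an (n, q) intermediate and the (1, q) result.
out_cheap = (C.transpose(-2, -1) * Q).sum(dim=0, keepdim=True).expand(m, q)

print(torch.allclose(out_gemm, out_cheap, atol=1e-3))   # True, up to fp error
```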
Surprised that BLAS doesn't support zero stride :( Do we need to write specialized kernels for these cases?
The joys of '80s libraries for you :-) The primitive we need here to avoid materialization is "broadcasted multiplication followed by a reduction" #32591 (comment); the einsum people have been talking about it for a long time. It sounds easy, but given how varied broadcasting patterns can be and how hard it generally is to write an efficient reduction for various access patterns, it turns out to be pretty hard. There are libraries doing it with varying degrees of success (KeOps and cuTENSOR are the ones I know of).
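As a rough illustration of why this primitive is awkward to emulate in eager PyTorch (made-up shapes; einsum shown only as the existing way to express the same contraction):
```python
import torch

A = torch.randn(512, 64)
B = torch.randn(64, 128)

# Broadcasted multiply followed by a reduction, written "by hand":
# the (512, 64, 128) product is fully materialized before the sum,
# which is exactly the allocation a fused kernel would avoid.
out_manual = (A.unsqueeze(-1) * B.unsqueeze(0)).sum(dim=1)

# einsum expresses the same contraction; this simple pattern can be lowered
# to matmul, but arbitrary broadcast-and-reduce patterns generally need a
# fused kernel (KeOps, cuTENSOR, ...) to avoid the intermediate.
out_einsum = torch.einsum('ik,kj->ij', A, B)

print(torch.allclose(out_manual, out_einsum, atol=1e-3))
```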
pearu left a comment:
Approving the PR as it is because my suggestion was invalid.
facebook-github-bot left a comment:
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Per title, fixes #40768