[MPS] Add linalg.householder_product for MPS #166090
kurtamohler wants to merge 3 commits into gh/kurtamohler/57/base
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166090
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 4da24cb with merge base 9038a30. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by:
threadgroup_barrier(mem_flags::mem_threadgroup);

T H_prod_0_to_i_rc =
    calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
At the moment, performance is much worse than that of the CPU impl, except in some cases where the number of batches is greater than the number of elements in the A matrix times the number of elements in the tau vector.
The vast majority of the runtime is spent in this matrix multiplication. I'm using a naive implementation of matmul, so we should be able to get much better performance if I change it to a tiled matmul. I suppose it should be possible to just reuse the tiled matmul defined earlier in this file, so I will look into that.
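For context, here's a minimal PyTorch-level sketch (not the Metal kernel) of what the op computes and why the kernel reduces to a chain of matmuls. The helper name `householder_product_reference` is hypothetical; it follows the reflector convention documented for `torch.linalg.householder_product`:

```python
import torch

def householder_product_reference(A, tau):
    # Hypothetical reference, for illustration only:
    #   Q = H_1 @ H_2 @ ... @ H_k, where H_i = I - tau[i] * v_i v_i^H
    # and v_i is column i of A with v_i[:i] = 0 and v_i[i] = 1.
    m, n = A.shape
    Q = torch.eye(m, dtype=A.dtype)
    for i in range(tau.shape[0]):
        v = A[:, i].clone()
        v[:i] = 0
        v[i] = 1
        H = torch.eye(m, dtype=A.dtype) - tau[i] * torch.outer(v, v.conj())
        Q = Q @ H  # one full m x m matmul per reflector
    return Q[:, :n]

A = torch.randn(5, 3, dtype=torch.float64)
Aq, tau = torch.geqrf(A)
print(torch.allclose(householder_product_reference(Aq, tau),
                     torch.linalg.householder_product(Aq, tau)))
```

Each of the k reflectors contributes a dense m×m matmul in this formulation, which is why the matmul dominates the runtime.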
I attempted to improve performance (in this branch) by changing the kernel to just generate the Householder matrices and use the existing do_metal_bmm for the matrix multiply. It improved performance slightly in some cases and decreased it in others, but overall it didn't make much of a difference. Maybe the CPU impl isn't actually doing a series of full matrix multiplies and instead uses some simplified formula. I'll have to take a look at the LAPACK impl. But I guess this is probably somewhat low priority.
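In case it helps with the follow-up: one candidate "simplified formula" is to apply each reflector as a rank-1 update instead of materializing H_i and doing a full matmul, using Q·H_i = Q − tau_i (Q v_i) v_iᴴ. This is only a sketch of the idea (I believe LAPACK's orgqr actually uses a blocked variant of this via larft/larfb); the helper name here is hypothetical:

```python
import torch

def householder_product_rank1(A, tau):
    # Hypothetical sketch: accumulate Q with a rank-1 update per reflector,
    #   Q <- Q - tau[i] * (Q @ v_i) v_i^H,
    # which avoids forming H_i and the full m x m matmul.
    m, n = A.shape
    Q = torch.eye(m, dtype=A.dtype)
    for i in range(tau.shape[0]):
        v = A[:, i].clone()
        v[:i] = 0
        v[i] = 1
        Q = Q - tau[i] * torch.outer(Q @ v, v.conj())
    return Q[:, :n]
```

Each step is then a matrix-vector product plus an outer-product update (roughly O(m²) work per reflector) rather than an O(m³) matmul.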
Fixes pytorch#166089
ghstack-source-id: 1d7b4a8
Pull-Request: pytorch#166090
I will follow up with a performance improvement PR once I understand how to do it.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
- linalg.householder_product for MPS #166090

Fixes #166089