
Conversation

@IvanYashchuk
Collaborator

Previous PR with the same content: #69752. Opening a new PR by request: #69752 (comment).


Previously, for a single input matrix A and a batched matrix B, A was expanded and cloned to B's batch shape before the LU decomposition was computed and the linear system solved, so the expensive factorization ran once per batch element.

With this PR, the LU decomposition is computed once for the single matrix, and the resulting factors are expanded and cloned only if a backend library call for the linear-system solve requires explicitly batched inputs.
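
A minimal sketch of the idea, written against the public torch.lu/torch.lu_solve APIs rather than the internal routines this PR changes; the shapes and the explicit expand/contiguous calls are illustrative assumptions:

```python
import torch

a = torch.randn(256, 256)       # single coefficient matrix A
b = torch.randn(1024, 256, 2)   # batch of 1024 right-hand sides

# Before: A was expanded and cloned to the batch shape *before*
# factorization, so the O(n^3) LU decomposition ran 1024 times.
lu_old, piv_old = torch.lu(a.expand(1024, 256, 256).clone())
x_old = torch.lu_solve(b, lu_old, piv_old)

# After: factorize the single matrix once; only the already-computed
# factors are expanded (and copied, when the backend needs contiguous
# batched inputs) for the much cheaper solve step.
lu, piv = torch.lu(a)
x_new = torch.lu_solve(b, lu.expand(1024, 256, 256).contiguous(),
                       piv.expand(1024, 256).contiguous())

print((x_old - x_new).abs().max())  # ~0: both paths solve the same systems
```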

Here's a basic comparison:

```python
# BEFORE THE PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
   ...: torch.linalg.solve(a, b)
   ...:
   ...:
329 ms ± 17.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# WITH THIS PR
In [1]: import torch
In [2]: a = torch.randn(256, 256)
In [3]: b = torch.randn(1024, 256, 2)
In [4]: %%timeit
   ...: torch.linalg.solve(a, b)
   ...:
   ...:
21.4 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

Fixes #71406, fixes #71610

@IvanYashchuk IvanYashchuk added the module: linear algebra and ciflow/all labels Jan 25, 2022
@IvanYashchuk IvanYashchuk requested a review from mruberry January 25, 2022 08:30
@pytorch-bot

pytorch-bot bot commented Jan 25, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/IvanYashchuk/pytorch/blob/077d7a35d4a626817cac23c0d239720c1351d644/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/cuda
Add ciflow labels to this PR to trigger more builds:

Workflow | Labels (bold = enabled) | Status
Triggered Workflows
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk ✅ triggered
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk ✅ triggered
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk ✅ triggered
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck ✅ triggered
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default 🚫 skipped
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default 🚫 skipped
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default 🚫 skipped
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default 🚫 skipped
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk, ciflow/xla 🚫 skipped
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/linux, ciflow/rocm, ciflow/trunk 🚫 skipped
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk 🚫 skipped
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk 🚫 skipped
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk 🚫 skipped
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk 🚫 skipped
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk 🚫 skipped
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk 🚫 skipped
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk 🚫 skipped
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk 🚫 skipped
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win 🚫 skipped

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 25, 2022


💊 CI failures summary and remediations

As of commit 262fe91 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build Test tools / test (1/1)

Step: "Test tools" (full log | diagnosis details | 🔁 rerun)

2022-02-18T09:54:30.2858971Z +  'pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test 407.616s',
2022-02-18T09:54:30.2859353Z +  '2021-01-06 20:58:28Z fcb69d2e '
2022-02-18T09:54:30.2859766Z +  'pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test 287.044s']
2022-02-18T09:54:30.2859959Z 
2022-02-18T09:54:30.2860087Z ======================================================================
2022-02-18T09:54:30.2860446Z FAIL: test_help_examples (test_test_history.TestTestHistory) (i=2)
2022-02-18T09:54:30.2860952Z ----------------------------------------------------------------------
2022-02-18T09:54:30.2861301Z Traceback (most recent call last):
2022-02-18T09:54:30.2861788Z   File "/home/runner/work/pytorch/pytorch/tools/test/test_test_history.py", line 70, in test_help_examples
2022-02-18T09:54:30.2862186Z     self.assertEqual(actual, expected)
2022-02-18T09:54:30.2862874Z AssertionError: Lists differ: ['202[21 chars]78395', '2021-02-10 11:13:34Z 594a66d7', '2021[220 chars]5f1'] != ['202[21 chars]78395    0.644s    0.312s', '2021-02-10 11:13:[363 chars]94s']
2022-02-18T09:54:30.2863162Z 
2022-02-18T09:54:30.2863277Z First differing element 0:
2022-02-18T09:54:30.2863785Z '2021-02-10 12:18:50Z 3cf78395'
2022-02-18T09:54:30.2864168Z '2021-02-10 12:18:50Z 3cf78395    0.644s    0.312s'
2022-02-18T09:54:30.2864340Z 
2022-02-18T09:54:30.2864507Z - ['2021-02-10 12:18:50Z 3cf78395',
2022-02-18T09:54:30.2864882Z + ['2021-02-10 12:18:50Z 3cf78395    0.644s    0.312s',
2022-02-18T09:54:30.2865164Z ?                                ++++++++++++++++++++
2022-02-18T09:54:30.2865384Z 
2022-02-18T09:54:30.2865568Z -  '2021-02-10 11:13:34Z 594a66d7',

1 failure not recognized by patterns:

Job Step Action
GitHub Actions ios-12-5-1-arm64-custom-ops / build Unknown 🔁 rerun

This comment was automatically generated by Dr. CI.

Collaborator

@lezcano lezcano left a comment


A few comments, but none of them needs to be addressed.

```cpp
  return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}

class BroadcastLinearIndices {
```
Collaborator

I believe this could be used in cusolver's gesvd, as we do a very similar thing there. Not for this PR ofc.
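
For intuition, a rough Python rendering of what the quoted helper computes; the shapes here are made-up examples:

```python
import torch

# Linear indices into the original (unexpanded) tensor, broadcast to the
# batched shape: batch elements that alias the same source matrix share an
# index, so a kernel can locate its input without materializing the copy.
original_shape, broadcast_shape = (3, 1), (3, 4)
idx = (torch.arange(3)
       .view(original_shape)
       .broadcast_to(broadcast_shape)
       .contiguous())
print(idx)
# tensor([[0, 0, 0, 0],
#         [1, 1, 1, 1],
#         [2, 2, 2, 2]])
```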

```cpp
auto b_stride = matrixStride(b);
auto lu_stride = matrixStride(lu);
auto pivots_stride = pivots_cpu.size(-1);
auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
```
Collaborator

I really like this way to get the "matrix stride". This is a great idea!
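
In (assumed-equivalent) Python terms: stride(-3) is the element distance between consecutive matrices of a batch, and treating an unbatched matrix as stride 0 means a backend loop that advances its data pointer by the stride simply re-reads the same matrix, i.e. broadcasts without a copy:

```python
import torch

def matrix_stride(t: torch.Tensor) -> int:
    # Distance (in elements) between consecutive matrices in a batch;
    # 0 for a single matrix, so a pointer advanced by this stride keeps
    # pointing at the same data: an implicit, copy-free broadcast.
    return t.stride(-3) if t.dim() > 2 else 0

print(matrix_stride(torch.randn(8, 4, 4)))  # 16 (= 4 * 4, contiguous batch)
print(matrix_stride(torch.randn(4, 4)))     # 0
```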

Comment on lines 2894 to 2895
```cpp
TORCH_INTERNAL_ASSERT(batchCount(b) == batchCount(lu), "batch_size of b and lu must be the same");
TORCH_INTERNAL_ASSERT(batchCount(lu) == batchCount(pivots.unsqueeze(-1)), "batch_size of lu and pivots must be the same");
```
Collaborator

Just another note. I just realised that we can implement a faster batchCount doing `t.dim() > 2 ? t.numel() / t.stride(-3) : 1` when we have a batch of column- or row-major matrices.
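
A hypothetical Python translation of that note; it assumes the batch dimensions are contiguous, so that stride(-3) equals the element count of one matrix:

```python
import torch

def batch_count_fast(t: torch.Tensor) -> int:
    # numel() divided by the per-matrix element count (the stride of the
    # innermost batch dimension) yields the product of all batch dims,
    # without looping over the sizes as a generic batchCount would.
    return t.numel() // t.stride(-3) if t.dim() > 2 else 1

print(batch_count_fast(torch.randn(4, 5, 3, 3)))  # 20 (= 4 * 5)
```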

```cpp
c10::MaybeOwned<Tensor> maybe_expand_lu(const Tensor& b, const Tensor& lu) {
  if (batchCount(b) != batchCount(lu)) {
    IntArrayRef b_batch_size(b.sizes().data(), b.dim() - 2);
    std::vector<int64_t> expand_size = b_batch_size.vec();
```
Collaborator

Nit. DimVector is better suited for these use cases.

@IvanYashchuk
Collaborator Author

CI fails with test_linalg_lstsq_cpu_float64

AssertionError: False is not true : Tensors failed to compare as equal!With rtol=1e-05 and atol=1e-05, found 241 element(s) (out of 256) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.6182836257459539 (0.9761708184203711 vs. 0.3578871926744173), which occurred at index (5, 5).

Tests pass locally. Maybe it's just bad seed + lstsq test being flaky?

@lezcano
Collaborator

lezcano commented Jan 25, 2022

This PR does not touch lstsq, right? #71222 might be related.

Collaborator

@mruberry mruberry left a comment

Stamped! Thank you @IvanYashchuk

@facebook-github-bot
Contributor

@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@mruberry
Collaborator

@IvanYashchuk Do you want to address @lezcano's comments and use it as an opportunity to retrigger CI and see if the lstsq error persists? If it does not we should file an issue for it being a flaky test.

@IvanYashchuk
Collaborator Author

I replaced std::vector<int64_t> with at::DimVector, thanks to Mario for the suggestion!

Let's see whether that lstsq test passes now.

@IvanYashchuk
Collaborator Author

This PR has accumulated conflicts and I need to resolve them.

@IvanYashchuk
Collaborator Author

Hey @mruberry, could you please try importing this PR?

@lezcano
Collaborator

lezcano commented Mar 7, 2022

Hi @mruberry. Could you import this one? As part of the LU stack that I'm working on, I plan to go and use all the new tools we have to optimise linalg.solve. For me to be able to do that, I need to have this one merged first, otherwise we'll have some spicy merge conflicts.

@facebook-github-bot
Contributor

@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Mar 15, 2022
…d attempt) (#71756)

Summary: Compute the LU decomposition once for a single input matrix and expand/clone the factors only when the backend solve requires it (the full description and benchmark are in the PR body above).

Fixes #71406, fixes #71610

Pull Request resolved: #71756

Reviewed By: ngimel

Differential Revision: D33771981

Pulled By: mruberry

fbshipit-source-id: 0917ee36a3eb622ff75d54787b1bffe26b41cb4a
@github-actions
Contributor

Hey @IvanYashchuk.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.


Labels

cla signed, module: linear algebra, open source, release notes: linalg_frontend, topic: performance

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.linalg.solve INTERNAL ASSERT FAILED
torch.linalg.solve updates are causing GPyTorch tests to fail

5 participants