
[MPS] [Sparse] unique_dim and sparse broadcast #163694

Closed

Isalia20 wants to merge 6 commits into pytorch:main from Isalia20:unique-dim-sparse-broadcast


Conversation

Isalia20 (Collaborator) commented Sep 23, 2025:

Implements unique_dim and sparse broadcast ops, and adds MPS dtypes for tests that are expected to fail; otherwise they would always fail because they run in double precision, which MPS does not support.
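For context, unique along a dimension returns the distinct slices plus optional inverse indices and counts. A minimal pure-Python model of those semantics for dim=0 (rows) — purely illustrative, not the MPS kernel:

```python
# Model of unique-along-dim semantics (dim=0 over rows).
# Returns (values, inverse, counts), with values in sorted order,
# mirroring the op's default behavior; an illustration only.
def unique_rows(rows):
    values = []    # first occurrence of each distinct row, sorted
    inverse = []   # for each input row, its index into `values`
    counts = []    # occurrences of each distinct row
    index_of = {}
    for r in sorted(set(map(tuple, rows))):
        index_of[r] = len(values)
        values.append(list(r))
        counts.append(0)
    for r in rows:
        i = index_of[tuple(r)]
        inverse.append(i)
        counts[i] += 1
    return values, inverse, counts
```

For example, `unique_rows([[1, 2], [3, 4], [1, 2]])` yields two distinct rows, with the inverse mapping both copies of `[1, 2]` to index 0.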

cc @kulinseth @malfet @DenisVieriu97 @jhavukainen

pytorch-bot bot commented Sep 23, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163694

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 72fd152 with merge base e671dcc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the release notes: mps label Sep 23, 2025
@Isalia20 added the module: mps, topic: improvements, and ciflow/trunk labels Sep 23, 2025
pytorch-bot bot commented Sep 23, 2025:

To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Sep 23, 2025
github-actions (Contributor) commented:

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@jbschlosser added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 24, 2025

@expectedFailureMPS
@coalescedonoff
@dtypes(torch.double)
Collaborator:

Doesn't the normal version support torch.float32? Or is it just that it's tested by optest somewhere else?

Isalia20 (Author) commented Sep 24, 2025:

Most of the sparse functions on other devices (CPU/CUDA) are tested in double; I'm not sure why. Maybe it's because gradcheck is imprecise in float32.
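If the gradcheck hypothesis is right, the mechanism is easy to demonstrate: numeric gradient checks use finite differences, and at lower precision the rounding of nearby function values swamps the estimate. A small sketch, where `to_f32` is a hypothetical helper simulating float32 rounding via an IEEE-754 binary32 round-trip:

```python
import struct

def to_f32(x):
    # round-trip through binary32 to simulate a float32 result
    return struct.unpack('f', struct.pack('f', x))[0]

def f(x):
    return x ** 3  # derivative: 3 * x**2

x, h = 1.5, 1e-3
exact = 3.0 * x * x

# central difference in float64: error is dominated by the O(h^2) truncation term
err64 = abs((f(x + h) - f(x - h)) / (2 * h) - exact)

# same formula with every intermediate rounded to float32 precision:
# rounding error of order eps32 / h is added on top of the truncation term
num = to_f32(to_f32(f(to_f32(x + h))) - to_f32(f(to_f32(x - h))))
err32 = abs(num / (2 * h) - exact)
```

Under these assumptions the float32-rounded estimate is considerably less accurate than the float64 one, which is consistent with sparse gradcheck tests defaulting to double.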

self.assertEqual(self.safeToDense(res), self.safeToDense(true_result))

@coalescedonoff
@expectedFailureMPS
Collaborator:

Remove expectedFailureMPS wrappers from these tests

Isalia20 (Author):

Why? We expect that test to fail.

# check_autograd(x, y)

@coalescedonoff
@expectedFailureMPS
Collaborator:

A bunch of these need to be removed

Isalia20 (Author):

I think no? Why should we remove them?

Collaborator:

Because it's no longer an expectedFailure on MPS?

Isalia20 (Author) commented Sep 24, 2025:

We expect this to fail. I added dtypesIfMPS to the test so expectFailureMPS triggers only when there’s an unexpected success. Without this, it would always fail because all tests using this decorator run in torch.float64, which always errors on MPS regardless of whether the op is implemented.

Contributor:

This is a good change, but do you mind submitting this one as part of a separate PR?

Isalia20 (Author):

Removed it for this PR; will submit it in a separate one.

# check_autograd(x, y)

@coalescedonoff
@expectedFailureMPS
Contributor:

This is a good change, but do you mind submitting this one as part of a separate PR?

}

Tensor perm;
for (int64_t c = cols - 1; c >= 0; --c) {
Contributor:

Nit

Suggested change
for (int64_t c = cols - 1; c >= 0; --c) {
for (auto c = cols - 1; c >= 0; --c) {

Isalia20 (Author):

updated

if (perm.defined()) {
keys = keys.index_select(0, perm);
}
Tensor idx = argsort(keys, /*dim=*/0, /*descending=*/false);
Contributor:

Suggested change
Tensor idx = argsort(keys, /*dim=*/0, /*descending=*/false);
const auto idx = argsort(keys, /*dim=*/0, /*descending=*/false);

Isalia20 (Author):

updated

}

static Tensor lexsort_rows_perm_mps(const Tensor& mat_2d) {
const auto rows = mat_2d.size(0), cols = mat_2d.size(1);
Contributor:

Doesn't C++17 allow something like

Suggested change
const auto rows = mat_2d.size(0), cols = mat_2d.size(1);
const auto [rows, cols] = mat_2d.sizes();

Isalia20 (Author) commented Sep 25, 2025:

Not supported for .sizes() I think; I get an error when trying to compile:

/pytorch/aten/src/ATen/native/mps/operations/Unique.mm:320:15: error: cannot decompose private member 'Data' of 'c10::ArrayRef<long long>'
    320 |   const auto [rows, cols] = mat_2d.sizes();

Comment on lines 328 to 332
if (perm.defined()) {
keys = keys.index_select(0, perm);
}
Tensor idx = argsort(keys, /*dim=*/0, /*descending=*/false);
perm = perm.defined() ? perm.index_select(0, idx) : std::move(idx);
Contributor:

Suggested change
if (perm.defined()) {
keys = keys.index_select(0, perm);
}
Tensor idx = argsort(keys, /*dim=*/0, /*descending=*/false);
perm = perm.defined() ? perm.index_select(0, idx) : std::move(idx);
if (!perm.defined()) {
perm = std::move(keys);
continue;
}
keys = keys.index_select(0, perm);
const auto idx = argsort(keys, /*dim=*/0, /*descending=*/false);
perm = perm.index_select(0, idx);

Isalia20 (Author):

rewrote in a simpler way
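The lexsort under discussion can be sketched in pure Python: stable sorts applied from the last key column to the first produce a lexicographic row ordering, because each later (more significant) pass preserves the order established by earlier passes on ties. An illustrative sketch of the idea, not the MPS implementation:

```python
def lexsort_rows_perm(mat):
    # permutation that sorts rows lexicographically, via repeated
    # stable sorts from the least-significant (last) column to the
    # most-significant (first) column
    perm = list(range(len(mat)))
    ncols = len(mat[0]) if mat else 0
    for c in reversed(range(ncols)):
        # Python's sort is stable, so earlier passes survive ties here
        perm.sort(key=lambda i: mat[i][c])
    return perm
```

For example, `lexsort_rows_perm([[2, 1], [1, 3], [1, 2]])` orders the rows as `[1, 2], [1, 3], [2, 1]`. In the kernel this role is played by repeated argsort/index_select passes over the key columns.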

auto output = at::empty(sizes, self.options());
auto inverse_indices = at::empty({0}, self.options().dtype(kLong));
auto counts = at::empty({0}, self.options().dtype(kLong));
return std::make_tuple(output, inverse_indices, counts);
Contributor:

Nit

Suggested change
return std::make_tuple(output, inverse_indices, counts);
return {output, inverse_indices, counts};

Isalia20 (Author):

updated

@Isalia20 Isalia20 added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 25, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Sep 25, 2025
@Isalia20 Isalia20 added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 26, 2025
malfet (Contributor) commented Sep 26, 2025:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

jainapurva pushed a commit that referenced this pull request Sep 29, 2025
Implements unique_dim, sparse broadcast ops and adds dtypes for mps for tests where we expect to fail, otherwise they would always fail due to being run in double precision

Pull Request resolved: #163694
Approved by: https://github.com/malfet
maggiemoss pushed a commit to maggiemoss/pytorch that referenced this pull request Sep 29, 2025
Implements unique_dim, sparse broadcast ops and adds dtypes for mps for tests where we expect to fail, otherwise they would always fail due to being run in double precision

Pull Request resolved: pytorch#163694
Approved by: https://github.com/malfet

Labels

ciflow/trunk, Merged, module: mps, open source, release notes: mps, topic: improvements, triaged

6 participants