Skip to content

[MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag#163281

Closed
kurtamohler wants to merge 5 commits intogh/kurtamohler/52/basefrom
gh/kurtamohler/52/head
Closed

[MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag#163281
kurtamohler wants to merge 5 commits intogh/kurtamohler/52/basefrom
gh/kurtamohler/52/head

Conversation

@kurtamohler
Copy link
Collaborator

@kurtamohler kurtamohler commented Sep 18, 2025

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163281

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 703a149 with merge base f8f230a (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kurtamohler added a commit that referenced this pull request Sep 18, 2025
@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Sep 18, 2025
[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Sep 18, 2025
template <typename T>
struct ReductionOp<EmbeddingBagMode::MAX, T> {
inline opmath_t<T> operator()(opmath_t<T> weight_val, opmath_t<T> out_val) {
return max(weight_val, out_val);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Which max are you using there? one from metal:: or one from c10::metal::?
I don't know if embedding bug supposed to carry about NaN, but if it is, make sure to use c10::metal:: wrapper, as regular one will not be able to handle NaN correctly

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Looks like the only way to make it match the CPU impl is to use metal::max and also initialize out_val with nan. I'll make those changes and add some nans to the test

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I've updated it to have the same nan behavior as the CPU impl. It doesn't use max at all, and instead uses comparison. Let me know what you think

Copy link
Contributor

@malfet malfet Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, I don't know which one is faster: ternary or max

thread I& max_idx,
I weight_idx,
bool pad) {
max_idx = (pad || new_out_val == out_val) ? max_idx : weight_idx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if compiler is smart enough, but wouldn't it be better to do something like

Suggested change
max_idx = (pad || new_out_val == out_val) ? max_idx : weight_idx;
if (!pad && new_out_val != out_val) {
max_idx = weight_idx;
}

Copy link
Collaborator Author

@kurtamohler kurtamohler Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the compiler is smart enough to avoid the thread divergence either, but the Metal documentation does recommend avoiding if statements that could potentially cause divergence: link

Apparently, XCode has a thread divergence counter in the profiler, so it would be possible to check this. But I don't have access to a graphical interface on the machine I'm using

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me experiment locally and I'll let you know

with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
helper(22, 0, [])

# TODO: This test can be removed once the backward pass of embedding_bag is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, what stops us from running existing forward tests from op_info and just add emebdding_bag to GRAD_FAILURES?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're already running the torch.nn.functional.embedding_bag opinfo tests, but that function does not return offset2bag, bag_size, and max_indices. There currently is no forward mode opinfo test that checks those

[ghstack-poisoned]
[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Sep 19, 2025
opmath_t<T> weight_val,
opmath_t<T> out_val,
bool is_first) {
return (is_first || weight_val > out_val) ? weight_val : out_val;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would also be a non-nan preserving, but I guess it's the same behavior as CPU

[ghstack-poisoned]
kurtamohler added a commit that referenced this pull request Sep 19, 2025
@kurtamohler
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 23, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
jainapurva pushed a commit that referenced this pull request Sep 29, 2025
@github-actions github-actions bot deleted the gh/kurtamohler/52/head branch October 24, 2025 02:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor open source release notes: mps Release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants