# [MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag #163281

kurtamohler wants to merge 5 commits into gh/kurtamohler/52/base from
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163281

✅ No Failures as of commit 703a149 with merge base f8f230a. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
```cpp
template <typename T>
struct ReductionOp<EmbeddingBagMode::MAX, T> {
  inline opmath_t<T> operator()(opmath_t<T> weight_val, opmath_t<T> out_val) {
    return max(weight_val, out_val);
```
Q: Which `max` are you using there? The one from `metal::` or the one from `c10::metal::`?
I don't know if embedding_bag is supposed to care about NaN, but if it is, make sure to use the `c10::metal::` wrapper, as the regular one will not handle NaN correctly.
Good catch. Looks like the only way to make it match the CPU impl is to use `metal::max` and also initialize `out_val` with NaN. I'll make those changes and add some NaNs to the test.
Ok, I've updated it to have the same NaN behavior as the CPU impl. It doesn't use `max` at all; instead it uses a comparison. Let me know what you think.
By the way, I don't know which one is faster: ternary or max
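For illustration, a small Python sketch (my own helper names, not PyTorch code) of the distinction this thread is about: an fmax-style max, like `metal::max`, silently drops a NaN operand, while a NaN-propagating max, which the comment suggests the `c10::metal::` wrapper provides, poisons the result:

```python
import math

def fmax_style(a, b):
    # metal::max-like behavior: if one operand is NaN, return the other one
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a > b else b

def nan_propagating_max(a, b):
    # NaN-propagating behavior: any NaN operand poisons the result
    if math.isnan(a) or math.isnan(b):
        return math.nan
    return a if a > b else b

print(fmax_style(math.nan, 1.0))           # 1.0
print(nan_propagating_max(math.nan, 1.0))  # nan
```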
```cpp
    thread I& max_idx,
    I weight_idx,
    bool pad) {
  max_idx = (pad || new_out_val == out_val) ? max_idx : weight_idx;
```
I'm not sure if the compiler is smart enough, but wouldn't it be better to do something like:

```diff
- max_idx = (pad || new_out_val == out_val) ? max_idx : weight_idx;
+ if (!pad && new_out_val != out_val) {
+   max_idx = weight_idx;
+ }
```
I'm not sure if the compiler is smart enough to avoid the thread divergence either, but the Metal documentation does recommend avoiding `if` statements that could potentially cause divergence: link
Apparently, Xcode has a thread divergence counter in the profiler, so it would be possible to check this. But I don't have access to a graphical interface on the machine I'm using.
Let me experiment locally and I'll let you know
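For what it's worth, the two formulations select the same value in every case, including NaN comparisons (the branch condition is the exact logical negation of the ternary's keep-condition). A quick standalone Python check with made-up inputs:

```python
import math
from itertools import product

def ternary_update(max_idx, weight_idx, pad, new_out_val, out_val):
    # max_idx = (pad || new_out_val == out_val) ? max_idx : weight_idx;
    return max_idx if (pad or new_out_val == out_val) else weight_idx

def branch_update(max_idx, weight_idx, pad, new_out_val, out_val):
    # if (!pad && new_out_val != out_val) { max_idx = weight_idx; }
    if not pad and new_out_val != out_val:
        max_idx = weight_idx
    return max_idx

vals = [0.0, 1.0, math.nan]
for pad, new_out, out in product([False, True], vals, vals):
    assert ternary_update(7, 3, pad, new_out, out) == \
           branch_update(7, 3, pad, new_out, out)
print("ternary and branch agree on all cases")
```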
```python
with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
    helper(22, 0, [])

# TODO: This test can be removed once the backward pass of embedding_bag is
```
Hmm, what stops us from running the existing forward tests from op_info and just adding embedding_bag to GRAD_FAILURES?
We're already running the `torch.nn.functional.embedding_bag` OpInfo tests, but that function does not return offset2bag, bag_size, and max_indices. There is currently no forward-mode OpInfo test that checks those.
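For reference, a toy pure-Python model (not the PyTorch implementation, and ignoring details like `include_last_offset`) of what the extra `_embedding_bag` outputs mean, assuming the usual convention that `offsets[b]` marks where bag `b` starts in the flattened input:

```python
def offset2bag_and_bag_size(num_indices, offsets):
    """Toy reference: offset2bag[i] is the bag containing input element i;
    bag_size[b] is the number of elements in bag b."""
    num_bags = len(offsets)
    offset2bag = [0] * num_indices
    bag_size = [0] * num_bags
    for b in range(num_bags):
        start = offsets[b]
        end = offsets[b + 1] if b + 1 < num_bags else num_indices
        for i in range(start, end):
            offset2bag[i] = b
        bag_size[b] = end - start
    return offset2bag, bag_size

o2b, bs = offset2bag_and_bag_size(5, [0, 2, 3])
print(o2b)  # [0, 0, 1, 2, 2]
print(bs)   # [2, 1, 2]
```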
```cpp
    opmath_t<T> weight_val,
    opmath_t<T> out_val,
    bool is_first) {
  return (is_first || weight_val > out_val) ? weight_val : out_val;
```
This would also be non-NaN-preserving, but I guess it's the same behavior as CPU.
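To make that concrete, a Python sketch (my own naming) of the reduction above: since NaN comparisons are false, a NaN weight only survives when it is the first element of the bag; a later NaN loses the comparison against whatever is already in `out_val`:

```python
import math

def reduce_max(weight_val, out_val, is_first):
    # return (is_first || weight_val > out_val) ? weight_val : out_val;
    return weight_val if (is_first or weight_val > out_val) else out_val

def bag_max(vals):
    out = 0.0  # placeholder; overwritten by the first element via is_first
    for i, v in enumerate(vals):
        out = reduce_max(v, out, is_first=(i == 0))
    return out

print(bag_max([math.nan, 1.0, 2.0]))  # nan
print(bag_max([1.0, math.nan, 2.0]))  # 2.0
```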
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ytorch#163281)

Part of pytorch#162270

Pull Request resolved: pytorch#163281
Approved by: https://github.com/malfet
Stack from ghstack (oldest at bottom):
-> [MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag #163281

Part of #162270
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben